Skip to content

Commit e1a777a

Browse files
committed
feat(convert): enhance dict and struct encoding with type awareness
1 parent 725e022 commit e1a777a

File tree

2 files changed

+173
-14
lines changed

2 files changed

+173
-14
lines changed

python/cocoindex/convert.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,42 @@
2222
)
2323

2424

25+
def _is_ktable_dict(value: dict[Any, Any]) -> bool:
26+
"""Check if a dict is being used as a KTable (dict[key, struct])."""
27+
if not value:
28+
return False
29+
first_val = next(iter(value.values()))
30+
return is_struct_type(type(first_val))
31+
32+
33+
def encode_engine_value_with_type(value: Any, target_type: Any = None) -> Any:
34+
"""Encode a Python value to an engine value with optional type awareness."""
35+
# Handle dict-to-struct conversion when target type is a struct
36+
if (isinstance(value, dict) and
37+
target_type is not None and
38+
not _is_ktable_dict(value)):
39+
40+
from .typing import analyze_type_info
41+
try:
42+
type_info = analyze_type_info(target_type)
43+
if type_info.kind == "Struct" and type_info.struct_type:
44+
# Convert dict to struct format
45+
if dataclasses.is_dataclass(type_info.struct_type):
46+
# Extract values in dataclass field order
47+
fields = dataclasses.fields(type_info.struct_type)
48+
return [encode_engine_value_with_type(value.get(f.name), f.type) for f in fields]
49+
elif is_namedtuple_type(type_info.struct_type):
50+
# Extract values in namedtuple field order
51+
field_names = getattr(type_info.struct_type, "_fields", ())
52+
return [encode_engine_value_with_type(value.get(name),
53+
type_info.struct_type.__annotations__.get(name)) for name in field_names]
54+
except:
55+
# If type analysis fails, fall back to regular encoding
56+
pass
57+
58+
return encode_engine_value(value)
59+
60+
2561
def encode_engine_value(value: Any) -> Any:
2662
"""Encode a Python value to an engine value."""
2763
if dataclasses.is_dataclass(value):
@@ -37,16 +73,11 @@ def encode_engine_value(value: Any) -> Any:
3773
return value
3874
if isinstance(value, (list, tuple)):
3975
return [encode_engine_value(v) for v in value]
40-
if isinstance(value, dict):
41-
if not value:
42-
return {}
43-
44-
first_val = next(iter(value.values()))
45-
if is_struct_type(type(first_val)): # KTable
46-
return [
47-
[encode_engine_value(k)] + encode_engine_value(v)
48-
for k, v in value.items()
49-
]
76+
if isinstance(value, dict) and _is_ktable_dict(value):
77+
return [
78+
[encode_engine_value(k)] + encode_engine_value(v)
79+
for k, v in value.items()
80+
]
5081
return value
5182

5283

@@ -258,6 +289,17 @@ def _make_engine_struct_value_decoder(
258289
)
259290
for name in fields
260291
}
292+
elif dst_struct_type is dict or (hasattr(dst_struct_type, "__origin__") and
293+
getattr(dst_struct_type, "__origin__") is dict):
294+
parameters = {
295+
f["name"]: inspect.Parameter(
296+
name=f["name"],
297+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
298+
default=inspect.Parameter.empty,
299+
annotation=Any,
300+
)
301+
for f in src_fields
302+
}
261303
else:
262304
raise ValueError(f"Unsupported struct type: {dst_struct_type}")
263305

@@ -288,10 +330,18 @@ def make_closure_for_value(
288330
field_value_decoder = [
289331
make_closure_for_value(name, param) for (name, param) in parameters.items()
290332
]
291-
292-
return lambda values: dst_struct_type(
293-
*(decoder(values) for decoder in field_value_decoder)
294-
)
333+
if dst_struct_type is dict or (hasattr(dst_struct_type, "__origin__") and
334+
getattr(dst_struct_type, "__origin__") is dict):
335+
def dict_decoder(values: list[Any]) -> dict[str, Any]:
336+
return {
337+
name: decoder(values)
338+
for (name, _), decoder in zip(parameters.items(), field_value_decoder)
339+
}
340+
return dict_decoder
341+
else:
342+
return lambda values: dst_struct_type(
343+
*(decoder(values) for decoder in field_value_decoder)
344+
)
295345

296346

297347
def dump_engine_object(v: Any) -> Any:

python/cocoindex/tests/test_convert.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,3 +1229,112 @@ class MixedStruct:
12291229
annotated_float=2.0,
12301230
)
12311231
validate_full_roundtrip(instance, MixedStruct)
1232+
1233+
1234+
def test_dict_struct_encoding() -> None:
1235+
# Test encoding dict as struct - this should only happen in type-aware context
1236+
# For generic encoding, dict is passed through as-is
1237+
from cocoindex.convert import encode_engine_value
1238+
dict_value = {"name": "Alice", "age": 30, "city": "New York"}
1239+
result = encode_engine_value(dict_value)
1240+
assert result == {"name": "Alice", "age": 30, "city": "New York"} # Dict preserved
1241+
1242+
1243+
def test_dict_struct_decoding() -> None:
1244+
# Test decoding to dict as struct
1245+
from cocoindex.convert import _make_engine_struct_value_decoder
1246+
from typing import Dict, Any
1247+
1248+
src_fields = [
1249+
{"name": "name", "type": {"kind": "Str"}},
1250+
{"name": "age", "type": {"kind": "Int64"}},
1251+
{"name": "city", "type": {"kind": "Str"}},
1252+
]
1253+
1254+
decoder = _make_engine_struct_value_decoder([], src_fields, Dict[str, Any])
1255+
result = decoder(["Alice", 30, "New York"])
1256+
assert result == {"name": "Alice", "age": 30, "city": "New York"}
1257+
1258+
1259+
def test_dict_untyped_struct_decoding() -> None:
1260+
# Test decoding to untyped dict as struct
1261+
from cocoindex.convert import _make_engine_struct_value_decoder
1262+
1263+
src_fields = [
1264+
{"name": "name", "type": {"kind": "Str"}},
1265+
{"name": "age", "type": {"kind": "Int64"}},
1266+
{"name": "city", "type": {"kind": "Str"}},
1267+
]
1268+
1269+
decoder = _make_engine_struct_value_decoder([], src_fields, dict)
1270+
result = decoder(["Alice", 30, "New York"])
1271+
assert result == {"name": "Alice", "age": 30, "city": "New York"}
1272+
1273+
1274+
def test_dict_struct_vs_ktable() -> None:
1275+
# Test that dict-as-struct and dict-as-KTable are handled differently
1276+
from cocoindex.convert import encode_engine_value, _is_ktable_dict
1277+
1278+
# Dict as struct (simple values) - should be passed through in generic encoding
1279+
dict_struct = {"name": "Alice", "age": 30}
1280+
assert not _is_ktable_dict(dict_struct)
1281+
struct_result = encode_engine_value(dict_struct)
1282+
assert struct_result == {"name": "Alice", "age": 30} # Preserved as dict
1283+
1284+
# Dict as KTable (struct values)
1285+
@dataclass
1286+
class Person:
1287+
name: str
1288+
age: int
1289+
1290+
dict_ktable = {"p1": Person("Alice", 30), "p2": Person("Bob", 25)}
1291+
assert _is_ktable_dict(dict_ktable)
1292+
ktable_result = encode_engine_value(dict_ktable)
1293+
assert ktable_result == [["p1", "Alice", 30], ["p2", "Bob", 25]]
1294+
1295+
1296+
def test_dict_to_struct_conversion() -> None:
1297+
# Test dict to struct conversion with existing struct types
1298+
from cocoindex.convert import encode_engine_value_with_type
1299+
from dataclasses import dataclass
1300+
1301+
@dataclass
1302+
class Person:
1303+
name: str
1304+
age: int
1305+
city: str
1306+
1307+
dict_value = {"name": "Alice", "age": 30, "city": "New York"}
1308+
1309+
# Test encoding dict as struct
1310+
result = encode_engine_value_with_type(dict_value, Person)
1311+
assert result == ["Alice", 30, "New York"]
1312+
1313+
1314+
def test_struct_to_dict_conversion() -> None:
1315+
# Test struct to dict conversion
1316+
from cocoindex.convert import _make_engine_struct_value_decoder
1317+
from dataclasses import dataclass
1318+
from typing import Dict, Any
1319+
1320+
@dataclass
1321+
class Person:
1322+
name: str
1323+
age: int
1324+
city: str
1325+
1326+
src_fields = [
1327+
{"name": "name", "type": {"kind": "Str"}},
1328+
{"name": "age", "type": {"kind": "Int64"}},
1329+
{"name": "city", "type": {"kind": "Str"}},
1330+
]
1331+
1332+
# Test decoding struct to dict
1333+
decoder = _make_engine_struct_value_decoder([], src_fields, Dict[str, Any])
1334+
result = decoder(["Alice", 30, "New York"])
1335+
assert result == {"name": "Alice", "age": 30, "city": "New York"}
1336+
1337+
# Test decoding struct to untyped dict
1338+
decoder = _make_engine_struct_value_decoder([], src_fields, dict)
1339+
result = decoder(["Alice", 30, "New York"])
1340+
assert result == {"name": "Alice", "age": 30, "city": "New York"}

0 commit comments

Comments
 (0)