Skip to content

Commit 8a9855e

Browse files
committed
feat(convert): implement struct to dict binding with enhanced type handling
1 parent 725e022 commit 8a9855e

File tree

3 files changed

+162
-2
lines changed

3 files changed

+162
-2
lines changed

python/cocoindex/convert.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def make_engine_value_decoder(
8989
if dst_is_any:
9090
if src_type_kind == "Union":
9191
return lambda value: value[1]
92-
if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
92+
if src_type_kind == "Struct":
93+
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
94+
if src_type_kind in TABLE_TYPES:
9395
raise ValueError(
9496
f"Missing type annotation for `{''.join(field_path)}`."
9597
f"It's required for {src_type_kind} type."
@@ -98,6 +100,18 @@ def make_engine_value_decoder(
98100

99101
dst_type_info = analyze_type_info(dst_annotation)
100102

103+
# Handle struct -> dict binding for explicit dict annotations
104+
if (
105+
src_type_kind == "Struct"
106+
and dst_type_info.kind == "KTable"
107+
and dst_type_info.elem_type
108+
and isinstance(dst_type_info.elem_type, tuple)
109+
and len(dst_type_info.elem_type) == 2
110+
and dst_type_info.elem_type[0] is str
111+
and dst_type_info.elem_type[1] is Any
112+
):
113+
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
114+
101115
if src_type_kind == "Union":
102116
dst_type_variants = (
103117
dst_type_info.union_variant_types
@@ -294,6 +308,34 @@ def make_closure_for_value(
294308
)
295309

296310

311+
def _make_engine_struct_to_dict_decoder(
312+
field_path: list[str],
313+
src_fields: list[dict[str, Any]],
314+
) -> Callable[[list[Any]], dict[str, Any]]:
315+
"""Make a decoder from engine field values to a Python dict."""
316+
317+
field_decoders = []
318+
for i, field_schema in enumerate(src_fields):
319+
field_name = field_schema["name"]
320+
field_path.append(f".{field_name}")
321+
field_decoder = make_engine_value_decoder(
322+
field_path,
323+
field_schema["type"],
324+
Any, # Use Any for recursive decoding
325+
)
326+
field_path.pop()
327+
field_decoders.append((field_name, field_decoder))
328+
329+
def decode_to_dict(values: list[Any]) -> dict[str, Any]:
330+
result = {}
331+
for i, (field_name, field_decoder) in enumerate(field_decoders):
332+
if i < len(values):
333+
result[field_name] = field_decoder(values[i])
334+
return result
335+
336+
return decode_to_dict
337+
338+
297339
def dump_engine_object(v: Any) -> Any:
298340
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
299341
if v is None:

python/cocoindex/tests/test_convert.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,3 +1229,115 @@ class MixedStruct:
12291229
annotated_float=2.0,
12301230
)
12311231
validate_full_roundtrip(instance, MixedStruct)
1232+
1233+
1234+
def test_roundtrip_struct_to_dict_binding() -> None:
1235+
"""Test struct -> dict binding with Any annotation."""
1236+
1237+
@dataclass
1238+
class SimpleStruct:
1239+
name: str
1240+
value: int
1241+
price: float
1242+
1243+
instance = SimpleStruct("test", 42, 3.14)
1244+
expected_dict = {"name": "test", "value": 42, "price": 3.14}
1245+
1246+
# Test Any annotation
1247+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, Any))
1248+
1249+
1250+
def test_roundtrip_struct_to_dict_explicit() -> None:
1251+
"""Test struct -> dict binding with explicit dict annotations."""
1252+
1253+
@dataclass
1254+
class Product:
1255+
id: str
1256+
name: str
1257+
price: float
1258+
active: bool
1259+
1260+
instance = Product("P1", "Widget", 29.99, True)
1261+
expected_dict = {"id": "P1", "name": "Widget", "price": 29.99, "active": True}
1262+
1263+
# Test explicit dict annotations
1264+
validate_full_roundtrip(
1265+
instance, Product, (expected_dict, dict), (expected_dict, dict[str, Any])
1266+
)
1267+
1268+
1269+
def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
1270+
"""Test struct -> dict binding with None annotation."""
1271+
1272+
@dataclass
1273+
class Config:
1274+
host: str
1275+
port: int
1276+
debug: bool
1277+
1278+
instance = Config("localhost", 8080, True)
1279+
expected_dict = {"host": "localhost", "port": 8080, "debug": True}
1280+
1281+
# Test None annotation (should be treated as Any)
1282+
validate_full_roundtrip(instance, Config, (expected_dict, None))
1283+
1284+
1285+
def test_roundtrip_struct_to_dict_nested() -> None:
1286+
"""Test struct -> dict binding with nested structs."""
1287+
1288+
@dataclass
1289+
class Address:
1290+
street: str
1291+
city: str
1292+
1293+
@dataclass
1294+
class Person:
1295+
name: str
1296+
age: int
1297+
address: Address
1298+
1299+
address = Address("123 Main St", "Anytown")
1300+
person = Person("John", 30, address)
1301+
expected_dict = {
1302+
"name": "John",
1303+
"age": 30,
1304+
"address": {"street": "123 Main St", "city": "Anytown"},
1305+
}
1306+
1307+
# Test nested struct conversion
1308+
validate_full_roundtrip(person, Person, (expected_dict, dict[str, Any]))
1309+
1310+
1311+
def test_roundtrip_struct_to_dict_with_list() -> None:
1312+
"""Test struct -> dict binding with list fields."""
1313+
1314+
@dataclass
1315+
class Team:
1316+
name: str
1317+
members: list[str]
1318+
active: bool
1319+
1320+
instance = Team("Dev Team", ["Alice", "Bob", "Charlie"], True)
1321+
expected_dict = {
1322+
"name": "Dev Team",
1323+
"members": ["Alice", "Bob", "Charlie"],
1324+
"active": True,
1325+
}
1326+
1327+
validate_full_roundtrip(instance, Team, (expected_dict, dict))
1328+
1329+
1330+
def test_roundtrip_namedtuple_to_dict_binding() -> None:
1331+
"""Test NamedTuple -> dict binding."""
1332+
1333+
class Point(NamedTuple):
1334+
x: float
1335+
y: float
1336+
z: float
1337+
1338+
instance = Point(1.0, 2.0, 3.0)
1339+
expected_dict = {"x": 1.0, "y": 2.0, "z": 3.0}
1340+
1341+
validate_full_roundtrip(
1342+
instance, Point, (expected_dict, dict), (expected_dict, Any)
1343+
)

python/cocoindex/typing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
241241

242242
elif base_type is collections.abc.Mapping or base_type is dict:
243243
args = typing.get_args(t)
244-
elem_type = (args[0], args[1])
244+
if len(args) == 0: # Handle untyped dict
245+
elem_type = (str, Any)
246+
else:
247+
elem_type = (args[0], args[1])
245248
kind = "KTable"
246249
elif base_type in (types.UnionType, typing.Union):
247250
possible_types = typing.get_args(t)
@@ -282,6 +285,9 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
282285
kind = "OffsetDateTime"
283286
elif t is datetime.timedelta:
284287
kind = "TimeDelta"
288+
elif t is dict:
289+
elem_type = (str, Any)
290+
kind = "KTable"
285291
else:
286292
raise ValueError(f"type unsupported yet: {t}")
287293

0 commit comments

Comments
 (0)