Skip to content

Commit 3139e18

Browse files
authored
feat(convert): simplify uptyped conversion and cover more scenarios (#826)
* feat(convert): simplify uptyped conversion and cover more scenarios * test: add more cases
1 parent 93e2890 commit 3139e18

File tree

2 files changed

+138
-90
lines changed

2 files changed

+138
-90
lines changed

python/cocoindex/convert.py

Lines changed: 53 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def make_engine_value_decoder(
9595
field_path: list[str],
9696
src_type: dict[str, Any],
9797
dst_type_info: AnalyzedTypeInfo,
98+
for_key: bool = False,
9899
) -> Callable[[Any], Any]:
99100
"""
100101
Make a decoder from an engine value to a Python value.
@@ -123,6 +124,7 @@ def make_engine_value_decoder(
123124
field_path,
124125
src_type["fields"],
125126
dst_type_info,
127+
for_key=for_key,
126128
)
127129

128130
if src_type_kind in TABLE_TYPES:
@@ -131,18 +133,18 @@ def make_engine_value_decoder(
131133

132134
if src_type_kind == "LTable":
133135
if isinstance(dst_type_variant, AnalyzedAnyType):
134-
return _make_engine_ltable_to_list_dict_decoder(
135-
field_path, engine_fields_schema
136-
)
137-
if not isinstance(dst_type_variant, AnalyzedListType):
136+
dst_elem_type = Any
137+
elif isinstance(dst_type_variant, AnalyzedListType):
138+
dst_elem_type = dst_type_variant.elem_type
139+
else:
138140
raise ValueError(
139141
f"Type mismatch for `{''.join(field_path)}`: "
140142
f"declared `{dst_type_info.core_type}`, a list type expected"
141143
)
142144
row_decoder = make_engine_struct_decoder(
143145
field_path,
144146
engine_fields_schema,
145-
analyze_type_info(dst_type_variant.elem_type),
147+
analyze_type_info(dst_elem_type),
146148
)
147149

148150
def decode(value: Any) -> Any | None:
@@ -152,10 +154,11 @@ def decode(value: Any) -> Any | None:
152154

153155
elif src_type_kind == "KTable":
154156
if isinstance(dst_type_variant, AnalyzedAnyType):
155-
return _make_engine_ktable_to_dict_dict_decoder(
156-
field_path, engine_fields_schema
157-
)
158-
if not isinstance(dst_type_variant, AnalyzedDictType):
157+
key_type, value_type = Any, Any
158+
elif isinstance(dst_type_variant, AnalyzedDictType):
159+
key_type = dst_type_variant.key_type
160+
value_type = dst_type_variant.value_type
161+
else:
159162
raise ValueError(
160163
f"Type mismatch for `{''.join(field_path)}`: "
161164
f"declared `{dst_type_info.core_type}`, a dict type expected"
@@ -166,13 +169,14 @@ def decode(value: Any) -> Any | None:
166169
key_decoder = make_engine_value_decoder(
167170
field_path,
168171
key_field_schema["type"],
169-
analyze_type_info(dst_type_variant.key_type),
172+
analyze_type_info(key_type),
173+
for_key=True,
170174
)
171175
field_path.pop()
172176
value_decoder = make_engine_struct_decoder(
173177
field_path,
174178
engine_fields_schema[1:],
175-
analyze_type_info(dst_type_variant.value_type),
179+
analyze_type_info(value_type),
176180
)
177181

178182
def decode(value: Any) -> Any | None:
@@ -316,26 +320,26 @@ def make_engine_struct_decoder(
316320
field_path: list[str],
317321
src_fields: list[dict[str, Any]],
318322
dst_type_info: AnalyzedTypeInfo,
323+
for_key: bool = False,
319324
) -> Callable[[list[Any]], Any]:
320325
"""Make a decoder from an engine field values to a Python value."""
321326

322327
dst_type_variant = dst_type_info.variant
323328

324-
use_dict = False
325329
if isinstance(dst_type_variant, AnalyzedAnyType):
326-
use_dict = True
330+
if for_key:
331+
return _make_engine_struct_to_tuple_decoder(field_path, src_fields)
332+
else:
333+
return _make_engine_struct_to_dict_decoder(field_path, src_fields, Any)
327334
elif isinstance(dst_type_variant, AnalyzedDictType):
328335
analyzed_key_type = analyze_type_info(dst_type_variant.key_type)
329-
analyzed_value_type = analyze_type_info(dst_type_variant.value_type)
330-
use_dict = (
336+
if (
331337
isinstance(analyzed_key_type.variant, AnalyzedAnyType)
332-
or (
333-
isinstance(analyzed_key_type.variant, AnalyzedBasicType)
334-
and analyzed_key_type.variant.kind == "Str"
338+
or analyzed_key_type.core_type is str
339+
):
340+
return _make_engine_struct_to_dict_decoder(
341+
field_path, src_fields, dst_type_variant.value_type
335342
)
336-
) and isinstance(analyzed_value_type.variant, AnalyzedAnyType)
337-
if use_dict:
338-
return _make_engine_struct_to_dict_decoder(field_path, src_fields)
339343

340344
if not isinstance(dst_type_variant, AnalyzedStructType):
341345
raise ValueError(
@@ -375,7 +379,7 @@ def make_closure_for_field(
375379
with ChildFieldPath(field_path, f".{name}"):
376380
if src_idx is not None:
377381
field_decoder = make_engine_value_decoder(
378-
field_path, src_fields[src_idx]["type"], type_info
382+
field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
379383
)
380384
return lambda values: field_decoder(values[src_idx])
381385

@@ -409,17 +413,19 @@ def make_closure_for_field(
409413
def _make_engine_struct_to_dict_decoder(
410414
field_path: list[str],
411415
src_fields: list[dict[str, Any]],
416+
value_type_annotation: Any,
412417
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
413418
"""Make a decoder from engine field values to a Python dict."""
414419

415420
field_decoders = []
416-
for i, field_schema in enumerate(src_fields):
421+
value_type_info = analyze_type_info(value_type_annotation)
422+
for field_schema in src_fields:
417423
field_name = field_schema["name"]
418424
with ChildFieldPath(field_path, f".{field_name}"):
419425
field_decoder = make_engine_value_decoder(
420426
field_path,
421427
field_schema["type"],
422-
analyze_type_info(Any), # Use Any for recursive decoding
428+
value_type_info,
423429
)
424430
field_decoders.append((field_name, field_decoder))
425431

@@ -438,76 +444,37 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
438444
return decode_to_dict
439445

440446

441-
def _make_engine_ltable_to_list_dict_decoder(
447+
def _make_engine_struct_to_tuple_decoder(
442448
field_path: list[str],
443449
src_fields: list[dict[str, Any]],
444-
) -> Callable[[list[Any] | None], list[dict[str, Any]] | None]:
445-
"""Make a decoder from engine LTable values to a list of dicts."""
446-
447-
# Create a decoder for each row (struct) to dict
448-
row_decoder = _make_engine_struct_to_dict_decoder(field_path, src_fields)
450+
) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
451+
"""Make a decoder from engine field values to a Python tuple."""
449452

450-
def decode_to_list_dict(values: list[Any] | None) -> list[dict[str, Any]] | None:
451-
if values is None:
452-
return None
453-
result = []
454-
for i, row_values in enumerate(values):
455-
decoded_row = row_decoder(row_values)
456-
if decoded_row is None:
457-
raise ValueError(
458-
f"LTable row at index {i} decoded to None, which is not allowed."
453+
field_decoders = []
454+
value_type_info = analyze_type_info(Any)
455+
for field_schema in src_fields:
456+
field_name = field_schema["name"]
457+
with ChildFieldPath(field_path, f".{field_name}"):
458+
field_decoders.append(
459+
make_engine_value_decoder(
460+
field_path,
461+
field_schema["type"],
462+
value_type_info,
459463
)
460-
result.append(decoded_row)
461-
return result
462-
463-
return decode_to_list_dict
464-
465-
466-
def _make_engine_ktable_to_dict_dict_decoder(
467-
field_path: list[str],
468-
src_fields: list[dict[str, Any]],
469-
) -> Callable[[list[Any] | None], dict[Any, dict[str, Any]] | None]:
470-
"""Make a decoder from engine KTable values to a dict of dicts."""
471-
472-
if not src_fields:
473-
raise ValueError("KTable must have at least one field for the key")
474-
475-
# First field is the key, remaining fields are the value
476-
key_field_schema = src_fields[0]
477-
value_fields_schema = src_fields[1:]
478-
479-
# Create decoders
480-
with ChildFieldPath(field_path, f".{key_field_schema.get('name', KEY_FIELD_NAME)}"):
481-
key_decoder = make_engine_value_decoder(
482-
field_path, key_field_schema["type"], analyze_type_info(Any)
483-
)
484-
485-
value_decoder = _make_engine_struct_to_dict_decoder(field_path, value_fields_schema)
464+
)
486465

487-
def decode_to_dict_dict(
488-
values: list[Any] | None,
489-
) -> dict[Any, dict[str, Any]] | None:
466+
def decode_to_tuple(values: list[Any] | None) -> tuple[Any, ...] | None:
490467
if values is None:
491468
return None
492-
result = {}
493-
for row_values in values:
494-
if not row_values:
495-
raise ValueError("KTable row must have at least 1 value (the key)")
496-
key = key_decoder(row_values[0])
497-
if len(row_values) == 1:
498-
value: dict[str, Any] = {}
499-
else:
500-
tmp = value_decoder(row_values[1:])
501-
if tmp is None:
502-
value = {}
503-
else:
504-
value = tmp
505-
if isinstance(key, dict):
506-
key = tuple(key.values())
507-
result[key] = value
508-
return result
469+
if len(field_decoders) != len(values):
470+
raise ValueError(
471+
f"Field count mismatch: expected {len(field_decoders)}, got {len(values)}"
472+
)
473+
return tuple(
474+
field_decoder(value) for value, field_decoder in zip(values, field_decoders)
475+
)
509476

510-
return decode_to_dict_dict
477+
return decode_to_tuple
511478

512479

513480
def dump_engine_object(v: Any) -> Any:

python/cocoindex/tests/test_convert.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,37 @@ class MixedStruct:
11601160
validate_full_roundtrip(instance, MixedStruct)
11611161

11621162

1163+
def test_roundtrip_simple_struct_to_dict_binding() -> None:
1164+
"""Test struct -> dict binding with Any annotation."""
1165+
1166+
@dataclass
1167+
class SimpleStruct:
1168+
first_name: str
1169+
last_name: str
1170+
1171+
instance = SimpleStruct("John", "Doe")
1172+
expected_dict = {"first_name": "John", "last_name": "Doe"}
1173+
1174+
# Test Any annotation
1175+
validate_full_roundtrip(
1176+
instance,
1177+
SimpleStruct,
1178+
(expected_dict, Any),
1179+
(expected_dict, dict),
1180+
(expected_dict, dict[Any, Any]),
1181+
(expected_dict, dict[str, Any]),
1182+
# For simple struct, all fields have the same type, so we can directly use the type as the dict value type.
1183+
(expected_dict, dict[Any, str]),
1184+
(expected_dict, dict[str, str]),
1185+
)
1186+
1187+
with pytest.raises(ValueError):
1188+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, int]))
1189+
1190+
with pytest.raises(ValueError):
1191+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
1192+
1193+
11631194
def test_roundtrip_struct_to_dict_binding() -> None:
11641195
"""Test struct -> dict binding with Any annotation."""
11651196

@@ -1173,7 +1204,20 @@ class SimpleStruct:
11731204
expected_dict = {"name": "test", "value": 42, "price": 3.14}
11741205

11751206
# Test Any annotation
1176-
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, Any))
1207+
validate_full_roundtrip(
1208+
instance,
1209+
SimpleStruct,
1210+
(expected_dict, Any),
1211+
(expected_dict, dict),
1212+
(expected_dict, dict[Any, Any]),
1213+
(expected_dict, dict[str, Any]),
1214+
)
1215+
1216+
with pytest.raises(ValueError):
1217+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[str, str]))
1218+
1219+
with pytest.raises(ValueError):
1220+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, dict[int, Any]))
11771221

11781222

11791223
def test_roundtrip_struct_to_dict_explicit() -> None:
@@ -1289,7 +1333,13 @@ class User:
12891333
]
12901334

12911335
# Test Any annotation
1292-
validate_full_roundtrip(users, list[User], (expected_list_dict, Any))
1336+
validate_full_roundtrip(
1337+
users,
1338+
list[User],
1339+
(expected_list_dict, Any),
1340+
(expected_list_dict, list[Any]),
1341+
(expected_list_dict, list[dict[str, Any]]),
1342+
)
12931343

12941344

12951345
def test_roundtrip_ktable_to_dict_dict_binding() -> None:
@@ -1313,7 +1363,17 @@ class Product:
13131363
}
13141364

13151365
# Test Any annotation
1316-
validate_full_roundtrip(products, dict[str, Product], (expected_dict_dict, Any))
1366+
validate_full_roundtrip(
1367+
products,
1368+
dict[str, Product],
1369+
(expected_dict_dict, Any),
1370+
(expected_dict_dict, dict),
1371+
(expected_dict_dict, dict[Any, Any]),
1372+
(expected_dict_dict, dict[str, Any]),
1373+
(expected_dict_dict, dict[Any, dict[Any, Any]]),
1374+
(expected_dict_dict, dict[str, dict[Any, Any]]),
1375+
(expected_dict_dict, dict[str, dict[str, Any]]),
1376+
)
13171377

13181378

13191379
def test_roundtrip_ktable_with_complex_key() -> None:
@@ -1339,7 +1399,28 @@ class Order:
13391399
}
13401400

13411401
# Test Any annotation
1342-
validate_full_roundtrip(orders, dict[OrderKey, Order], (expected_dict_dict, Any))
1402+
validate_full_roundtrip(
1403+
orders,
1404+
dict[OrderKey, Order],
1405+
(expected_dict_dict, Any),
1406+
(expected_dict_dict, dict),
1407+
(expected_dict_dict, dict[Any, Any]),
1408+
(expected_dict_dict, dict[Any, dict[str, Any]]),
1409+
(
1410+
{
1411+
("shop1", 1): Order("Alice", 100.0),
1412+
("shop2", 2): Order("Bob", 200.0),
1413+
},
1414+
dict[Any, Order],
1415+
),
1416+
(
1417+
{
1418+
OrderKey("shop1", 1): {"customer": "Alice", "total": 100.0},
1419+
OrderKey("shop2", 2): {"customer": "Bob", "total": 200.0},
1420+
},
1421+
dict[OrderKey, Any],
1422+
),
1423+
)
13431424

13441425

13451426
def test_roundtrip_ltable_with_nested_structs() -> None:

0 commit comments

Comments
 (0)