Skip to content

Commit 9740124

Browse files
authored
feat(convert): enhance dict and struct encoding with type awareness (#727)
* feat(convert): implement struct to dict binding with enhanced type handling * feat(convert): enhance struct to dict binding and improve type validation for annotations * refactor(convert): reorganize struct to dict binding logic for clarity and maintainability * remove unnecessary whitespace for ruff check
1 parent c40960c commit 9740124

File tree

3 files changed

+171
-3
lines changed

3 files changed

+171
-3
lines changed

python/cocoindex/convert.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,26 @@ def make_engine_value_decoder(
8989
if dst_is_any:
9090
if src_type_kind == "Union":
9191
return lambda value: value[1]
92-
if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
92+
if src_type_kind == "Struct":
93+
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
94+
if src_type_kind in TABLE_TYPES:
9395
raise ValueError(
9496
f"Missing type annotation for `{''.join(field_path)}`."
9597
f"It's required for {src_type_kind} type."
9698
)
9799
return lambda value: value
98100

101+
# Handle struct -> dict binding for explicit dict annotations
102+
is_dict_annotation = False
103+
if dst_annotation is dict:
104+
is_dict_annotation = True
105+
elif getattr(dst_annotation, "__origin__", None) is dict:
106+
args = getattr(dst_annotation, "__args__", ())
107+
if args == (str, Any):
108+
is_dict_annotation = True
109+
if is_dict_annotation and src_type_kind == "Struct":
110+
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
111+
99112
dst_type_info = analyze_type_info(dst_annotation)
100113

101114
if src_type_kind == "Union":
@@ -294,6 +307,39 @@ def make_closure_for_value(
294307
)
295308

296309

310+
def _make_engine_struct_to_dict_decoder(
311+
field_path: list[str],
312+
src_fields: list[dict[str, Any]],
313+
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
314+
"""Make a decoder from engine field values to a Python dict."""
315+
316+
field_decoders = []
317+
for i, field_schema in enumerate(src_fields):
318+
field_name = field_schema["name"]
319+
field_path.append(f".{field_name}")
320+
field_decoder = make_engine_value_decoder(
321+
field_path,
322+
field_schema["type"],
323+
Any, # Use Any for recursive decoding
324+
)
325+
field_path.pop()
326+
field_decoders.append((field_name, field_decoder))
327+
328+
def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
329+
if values is None:
330+
return None
331+
if len(field_decoders) != len(values):
332+
raise ValueError(
333+
f"Field count mismatch: expected {len(field_decoders)}, got {len(values)}"
334+
)
335+
return {
336+
field_name: field_decoder(value)
337+
for value, (field_name, field_decoder) in zip(values, field_decoders)
338+
}
339+
340+
return decode_to_dict
341+
342+
297343
def dump_engine_object(v: Any) -> Any:
298344
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
299345
if v is None:

python/cocoindex/tests/test_convert.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,3 +1229,115 @@ class MixedStruct:
12291229
annotated_float=2.0,
12301230
)
12311231
validate_full_roundtrip(instance, MixedStruct)
1232+
1233+
1234+
def test_roundtrip_struct_to_dict_binding() -> None:
1235+
"""Test struct -> dict binding with Any annotation."""
1236+
1237+
@dataclass
1238+
class SimpleStruct:
1239+
name: str
1240+
value: int
1241+
price: float
1242+
1243+
instance = SimpleStruct("test", 42, 3.14)
1244+
expected_dict = {"name": "test", "value": 42, "price": 3.14}
1245+
1246+
# Test Any annotation
1247+
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, Any))
1248+
1249+
1250+
def test_roundtrip_struct_to_dict_explicit() -> None:
1251+
"""Test struct -> dict binding with explicit dict annotations."""
1252+
1253+
@dataclass
1254+
class Product:
1255+
id: str
1256+
name: str
1257+
price: float
1258+
active: bool
1259+
1260+
instance = Product("P1", "Widget", 29.99, True)
1261+
expected_dict = {"id": "P1", "name": "Widget", "price": 29.99, "active": True}
1262+
1263+
# Test explicit dict annotations
1264+
validate_full_roundtrip(
1265+
instance, Product, (expected_dict, dict), (expected_dict, dict[str, Any])
1266+
)
1267+
1268+
1269+
def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
1270+
"""Test struct -> dict binding with None annotation."""
1271+
1272+
@dataclass
1273+
class Config:
1274+
host: str
1275+
port: int
1276+
debug: bool
1277+
1278+
instance = Config("localhost", 8080, True)
1279+
expected_dict = {"host": "localhost", "port": 8080, "debug": True}
1280+
1281+
# Test None annotation (should be treated as Any)
1282+
validate_full_roundtrip(instance, Config, (expected_dict, None))
1283+
1284+
1285+
def test_roundtrip_struct_to_dict_nested() -> None:
1286+
"""Test struct -> dict binding with nested structs."""
1287+
1288+
@dataclass
1289+
class Address:
1290+
street: str
1291+
city: str
1292+
1293+
@dataclass
1294+
class Person:
1295+
name: str
1296+
age: int
1297+
address: Address
1298+
1299+
address = Address("123 Main St", "Anytown")
1300+
person = Person("John", 30, address)
1301+
expected_dict = {
1302+
"name": "John",
1303+
"age": 30,
1304+
"address": {"street": "123 Main St", "city": "Anytown"},
1305+
}
1306+
1307+
# Test nested struct conversion
1308+
validate_full_roundtrip(person, Person, (expected_dict, dict[str, Any]))
1309+
1310+
1311+
def test_roundtrip_struct_to_dict_with_list() -> None:
1312+
"""Test struct -> dict binding with list fields."""
1313+
1314+
@dataclass
1315+
class Team:
1316+
name: str
1317+
members: list[str]
1318+
active: bool
1319+
1320+
instance = Team("Dev Team", ["Alice", "Bob", "Charlie"], True)
1321+
expected_dict = {
1322+
"name": "Dev Team",
1323+
"members": ["Alice", "Bob", "Charlie"],
1324+
"active": True,
1325+
}
1326+
1327+
validate_full_roundtrip(instance, Team, (expected_dict, dict))
1328+
1329+
1330+
def test_roundtrip_namedtuple_to_dict_binding() -> None:
1331+
"""Test NamedTuple -> dict binding."""
1332+
1333+
class Point(NamedTuple):
1334+
x: float
1335+
y: float
1336+
z: float
1337+
1338+
instance = Point(1.0, 2.0, 3.0)
1339+
expected_dict = {"x": 1.0, "y": 2.0, "z": 3.0}
1340+
1341+
validate_full_roundtrip(
1342+
instance, Point, (expected_dict, dict), (expected_dict, Any)
1343+
)

python/cocoindex/typing.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ class AnalyzedTypeInfo:
168168

169169
def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
170170
"""
171-
Analyze a Python type and return the analyzed info.
171+
Analyze a Python type annotation and extract CocoIndex-specific type information.
172+
Only concrete CocoIndex type annotations are supported. Raises ValueError for Any, empty, or untyped dict types.
172173
"""
173174
if isinstance(t, tuple) and len(t) == 2:
174175
kt, vt = t
@@ -241,7 +242,12 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
241242

242243
elif base_type is collections.abc.Mapping or base_type is dict:
243244
args = typing.get_args(t)
244-
elem_type = (args[0], args[1])
245+
if len(args) == 0: # Handle untyped dict
246+
raise ValueError(
247+
"Untyped dict is not supported; please provide a concrete type, e.g., dict[str, Any]."
248+
)
249+
else:
250+
elem_type = (args[0], args[1])
245251
kind = "KTable"
246252
elif base_type in (types.UnionType, typing.Union):
247253
possible_types = typing.get_args(t)
@@ -282,6 +288,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
282288
kind = "OffsetDateTime"
283289
elif t is datetime.timedelta:
284290
kind = "TimeDelta"
291+
elif t is dict:
292+
raise ValueError(
293+
"Untyped dict is not supported; please provide a concrete type, e.g., dict[str, Any]."
294+
)
285295
else:
286296
raise ValueError(f"type unsupported yet: {t}")
287297

0 commit comments

Comments
 (0)