def _make_engine_struct_to_dict_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
    """Make a decoder from engine struct field values to a Python dict.

    One decoder is built per entry of ``src_fields`` by calling
    ``make_engine_value_decoder`` with ``Any`` as the destination annotation.
    The returned callable applies those decoders positionally and keys the
    results by field name.

    Args:
        field_path: Path segments of the value being decoded. It is extended
            with ``.<field name>`` while each field decoder is built and is
            restored afterwards, even if building a field decoder raises.
        src_fields: Field schemas of the source struct. Each entry is read
            for its ``"name"`` and ``"type"`` keys.

    Returns:
        A callable that maps a list of engine values (one per field, in
        schema order) to a ``dict`` keyed by field name. ``None`` input is
        passed through as ``None``. The callable raises ``ValueError`` when
        the number of values differs from the number of fields.
    """
    field_decoders: list[tuple[str, Callable[[Any], Any]]] = []
    for field_schema in src_fields:
        field_name = field_schema["name"]
        field_path.append(f".{field_name}")
        try:
            field_decoder = make_engine_value_decoder(
                field_path,
                field_schema["type"],
                Any,  # Use Any for recursive decoding
            )
        finally:
            # `field_path` is shared with the caller; always undo the append
            # so a failure on one field does not leave a stale segment behind.
            field_path.pop()
        field_decoders.append((field_name, field_decoder))

    def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
        if values is None:
            return None
        # `zip` would silently truncate on a length mismatch, so check first.
        if len(field_decoders) != len(values):
            raise ValueError(
                f"Field count mismatch: expected {len(field_decoders)}, got {len(values)}"
            )
        return {
            field_name: field_decoder(value)
            for value, (field_name, field_decoder) in zip(values, field_decoders)
        }

    return decode_to_dict
def test_roundtrip_struct_to_dict_nested() -> None:
    """A struct containing another struct decodes to a dict of dicts."""

    @dataclass
    class Address:
        street: str
        city: str

    @dataclass
    class Person:
        name: str
        age: int
        address: Address

    subject = Person(
        name="John",
        age=30,
        address=Address(street="123 Main St", city="Anytown"),
    )
    expected = {
        "name": "John",
        "age": 30,
        "address": {"street": "123 Main St", "city": "Anytown"},
    }

    # The inner struct must be converted as well, not just the outer one.
    validate_full_roundtrip(subject, Person, (expected, dict[str, Any]))


def test_roundtrip_struct_to_dict_with_list() -> None:
    """A struct with a list field decodes to a dict holding that list."""

    @dataclass
    class Team:
        name: str
        members: list[str]
        active: bool

    member_names = ["Alice", "Bob", "Charlie"]
    team = Team(name="Dev Team", members=list(member_names), active=True)
    expected = {
        "name": "Dev Team",
        "members": list(member_names),
        "active": True,
    }

    validate_full_roundtrip(team, Team, (expected, dict))


def test_roundtrip_namedtuple_to_dict_binding() -> None:
    """A NamedTuple value decodes to a dict for both `dict` and `Any`."""

    class Point(NamedTuple):
        x: float
        y: float
        z: float

    point = Point(x=1.0, y=2.0, z=3.0)
    expected = {"x": 1.0, "y": 2.0, "z": 3.0}

    validate_full_roundtrip(
        point,
        Point,
        (expected, dict),
        (expected, Any),
    )
""" if isinstance(t, tuple) and len(t) == 2: kt, vt = t @@ -241,7 +242,12 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: elif base_type is collections.abc.Mapping or base_type is dict: args = typing.get_args(t) - elem_type = (args[0], args[1]) + if len(args) == 0: # Handle untyped dict + raise ValueError( + "Untyped dict is not supported; please provide a concrete type, e.g., dict[str, Any]." + ) + else: + elem_type = (args[0], args[1]) kind = "KTable" elif base_type in (types.UnionType, typing.Union): possible_types = typing.get_args(t) @@ -282,6 +288,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: kind = "OffsetDateTime" elif t is datetime.timedelta: kind = "TimeDelta" + elif t is dict: + raise ValueError( + "Untyped dict is not supported; please provide a concrete type, e.g., dict[str, Any]." + ) else: raise ValueError(f"type unsupported yet: {t}")