Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def make_engine_value_decoder(
if dst_is_any:
if src_type_kind == "Union":
return lambda value: value[1]
if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
if src_type_kind == "Struct":
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
if src_type_kind in TABLE_TYPES:
raise ValueError(
f"Missing type annotation for `{''.join(field_path)}`."
f"It's required for {src_type_kind} type."
Expand All @@ -98,6 +100,18 @@ def make_engine_value_decoder(

dst_type_info = analyze_type_info(dst_annotation)

# Handle struct -> dict binding for explicit dict annotations
if (
src_type_kind == "Struct"
and dst_type_info.kind == "KTable"
and dst_type_info.elem_type
and isinstance(dst_type_info.elem_type, tuple)
and len(dst_type_info.elem_type) == 2
and dst_type_info.elem_type[0] is str
and dst_type_info.elem_type[1] is Any
):
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])

if src_type_kind == "Union":
dst_type_variants = (
dst_type_info.union_variant_types
Expand Down Expand Up @@ -294,6 +308,34 @@ def make_closure_for_value(
)


def _make_engine_struct_to_dict_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
) -> Callable[[list[Any]], dict[str, Any]]:
"""Make a decoder from engine field values to a Python dict."""

field_decoders = []
for i, field_schema in enumerate(src_fields):
field_name = field_schema["name"]
field_path.append(f".{field_name}")
field_decoder = make_engine_value_decoder(
field_path,
field_schema["type"],
Any, # Use Any for recursive decoding
)
field_path.pop()
field_decoders.append((field_name, field_decoder))

def decode_to_dict(values: list[Any]) -> dict[str, Any]:
result = {}
for i, (field_name, field_decoder) in enumerate(field_decoders):
if i < len(values):
result[field_name] = field_decoder(values[i])
return result

return decode_to_dict


def dump_engine_object(v: Any) -> Any:
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""
if v is None:
Expand Down
112 changes: 112 additions & 0 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,3 +1229,115 @@ class MixedStruct:
annotated_float=2.0,
)
validate_full_roundtrip(instance, MixedStruct)


def test_roundtrip_struct_to_dict_binding() -> None:
"""Test struct -> dict binding with Any annotation."""

@dataclass
class SimpleStruct:
name: str
value: int
price: float

instance = SimpleStruct("test", 42, 3.14)
expected_dict = {"name": "test", "value": 42, "price": 3.14}

# Test Any annotation
validate_full_roundtrip(instance, SimpleStruct, (expected_dict, Any))


def test_roundtrip_struct_to_dict_explicit() -> None:
"""Test struct -> dict binding with explicit dict annotations."""

@dataclass
class Product:
id: str
name: str
price: float
active: bool

instance = Product("P1", "Widget", 29.99, True)
expected_dict = {"id": "P1", "name": "Widget", "price": 29.99, "active": True}

# Test explicit dict annotations
validate_full_roundtrip(
instance, Product, (expected_dict, dict), (expected_dict, dict[str, Any])
)


def test_roundtrip_struct_to_dict_with_none_annotation() -> None:
"""Test struct -> dict binding with None annotation."""

@dataclass
class Config:
host: str
port: int
debug: bool

instance = Config("localhost", 8080, True)
expected_dict = {"host": "localhost", "port": 8080, "debug": True}

# Test None annotation (should be treated as Any)
validate_full_roundtrip(instance, Config, (expected_dict, None))


def test_roundtrip_struct_to_dict_nested() -> None:
"""Test struct -> dict binding with nested structs."""

@dataclass
class Address:
street: str
city: str

@dataclass
class Person:
name: str
age: int
address: Address

address = Address("123 Main St", "Anytown")
person = Person("John", 30, address)
expected_dict = {
"name": "John",
"age": 30,
"address": {"street": "123 Main St", "city": "Anytown"},
}

# Test nested struct conversion
validate_full_roundtrip(person, Person, (expected_dict, dict[str, Any]))


def test_roundtrip_struct_to_dict_with_list() -> None:
"""Test struct -> dict binding with list fields."""

@dataclass
class Team:
name: str
members: list[str]
active: bool

instance = Team("Dev Team", ["Alice", "Bob", "Charlie"], True)
expected_dict = {
"name": "Dev Team",
"members": ["Alice", "Bob", "Charlie"],
"active": True,
}

validate_full_roundtrip(instance, Team, (expected_dict, dict))


def test_roundtrip_namedtuple_to_dict_binding() -> None:
"""Test NamedTuple -> dict binding."""

class Point(NamedTuple):
x: float
y: float
z: float

instance = Point(1.0, 2.0, 3.0)
expected_dict = {"x": 1.0, "y": 2.0, "z": 3.0}

validate_full_roundtrip(
instance, Point, (expected_dict, dict), (expected_dict, Any)
)
8 changes: 7 additions & 1 deletion python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,10 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:

elif base_type is collections.abc.Mapping or base_type is dict:
args = typing.get_args(t)
elem_type = (args[0], args[1])
if len(args) == 0: # Handle untyped dict
elem_type = (str, Any)
else:
elem_type = (args[0], args[1])
kind = "KTable"
elif base_type in (types.UnionType, typing.Union):
possible_types = typing.get_args(t)
Expand Down Expand Up @@ -282,6 +285,9 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
kind = "OffsetDateTime"
elif t is datetime.timedelta:
kind = "TimeDelta"
elif t is dict:
elem_type = (str, Any)
kind = "KTable"
else:
raise ValueError(f"type unsupported yet: {t}")

Expand Down