diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py
index b47a20870..84e004a3c 100644
--- a/python/cocoindex/convert.py
+++ b/python/cocoindex/convert.py
@@ -9,12 +9,11 @@
 import inspect
 import warnings
 from enum import Enum
-from typing import Any, Callable, Mapping, Sequence, Type, get_origin
+from typing import Any, Callable, Mapping, get_origin
 
 import numpy as np
 
 from .typing import (
-    TABLE_TYPES,
     AnalyzedAnyType,
     AnalyzedBasicType,
     AnalyzedDictType,
@@ -27,6 +26,11 @@
     encode_enriched_type,
     is_namedtuple_type,
     is_numpy_number_type,
+    ValueType,
+    FieldSchema,
+    BasicValueType,
+    StructType,
+    TableType,
 )
 
 
@@ -172,7 +176,7 @@ def encode_basic_value(value: Any) -> Any:
 
 def make_engine_key_decoder(
     field_path: list[str],
-    key_fields_schema: list[dict[str, Any]],
+    key_fields_schema: list[FieldSchema],
     dst_type_info: AnalyzedTypeInfo,
 ) -> Callable[[Any], Any]:
     """
@@ -183,7 +187,7 @@
     ):
         single_key_decoder = make_engine_value_decoder(
             field_path,
-            key_fields_schema[0]["type"],
+            key_fields_schema[0].value_type.type,
            dst_type_info,
            for_key=True,
        )
@@ -203,7 +207,7 @@ def key_decoder(value: list[Any]) -> Any:
 
 def make_engine_value_decoder(
     field_path: list[str],
-    src_type: dict[str, Any],
+    src_type: ValueType,
     dst_type_info: AnalyzedTypeInfo,
     for_key: bool = False,
 ) -> Callable[[Any], Any]:
@@ -219,7 +223,7 @@
         A decoder from an engine value to a Python value.
     """
 
-    src_type_kind = src_type["kind"]
+    src_type_kind = src_type.kind
 
     dst_type_variant = dst_type_info.variant
 
@@ -229,19 +233,19 @@
             f"declared `{dst_type_info.core_type}`, an unsupported type"
         )
 
-    if src_type_kind == "Struct":
+    if isinstance(src_type, StructType):  # type: ignore[redundant-cast]
         return make_engine_struct_decoder(
             field_path,
-            src_type["fields"],
+            src_type.fields,
             dst_type_info,
             for_key=for_key,
         )
 
-    if src_type_kind in TABLE_TYPES:
+    if isinstance(src_type, TableType):  # type: ignore[redundant-cast]
         with ChildFieldPath(field_path, "[*]"):
-            engine_fields_schema = src_type["row"]["fields"]
+            engine_fields_schema = src_type.row.fields
 
-            if src_type_kind == "LTable":
+            if src_type.kind == "LTable":
                 if isinstance(dst_type_variant, AnalyzedAnyType):
                     dst_elem_type = Any
                 elif isinstance(dst_type_variant, AnalyzedListType):
@@ -262,7 +266,7 @@ def decode(value: Any) -> Any | None:
                     return None
                 return [row_decoder(v) for v in value]
 
-            elif src_type_kind == "KTable":
+            elif src_type.kind == "KTable":
                 if isinstance(dst_type_variant, AnalyzedAnyType):
                     key_type, value_type = Any, Any
                 elif isinstance(dst_type_variant, AnalyzedDictType):
@@ -274,7 +278,7 @@ def decode(value: Any) -> Any | None:
                         f"declared `{dst_type_info.core_type}`, a dict type expected"
                     )
 
-                num_key_parts = src_type.get("num_key_parts", 1)
+                num_key_parts = src_type.num_key_parts or 1
                 key_decoder = make_engine_key_decoder(
                     field_path,
                     engine_fields_schema[0:num_key_parts],
@@ -298,7 +302,7 @@ def decode(value: Any) -> Any | None:
 
             return decode
 
-    if src_type_kind == "Union":
+    if isinstance(src_type, BasicValueType) and src_type.kind == "Union":
         if isinstance(dst_type_variant, AnalyzedAnyType):
             return lambda value: value[1]
 
@@ -307,7 +311,10 @@ def decode(value: Any) -> Any | None:
             if isinstance(dst_type_variant, AnalyzedUnionType)
             else [dst_type_info]
         )
-        src_type_variants = src_type["types"]
+        # mypy: union info exists for Union kind
+        assert src_type.union is not None  # type: ignore[unreachable]
+        src_type_variants_basic: list[BasicValueType] = src_type.union.variants
+        src_type_variants = src_type_variants_basic
         decoders = []
         for i, src_type_variant in enumerate(src_type_variants):
             with ChildFieldPath(field_path, f"[{i}]"):
@@ -331,7 +338,7 @@ def decode(value: Any) -> Any | None:
     if isinstance(dst_type_variant, AnalyzedAnyType):
         return lambda value: value
 
-    if src_type_kind == "Vector":
+    if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
         field_path_str = "".join(field_path)
         if not isinstance(dst_type_variant, AnalyzedListType):
             raise ValueError(
@@ -350,9 +357,11 @@ def decode(value: Any) -> Any | None:
         if is_numpy_number_type(dst_type_variant.elem_type):
             scalar_dtype = dst_type_variant.elem_type
         else:
+            # mypy: vector info exists for Vector kind
+            assert src_type.vector is not None  # type: ignore[unreachable]
             vec_elem_decoder = make_engine_value_decoder(
                 field_path + ["[*]"],
-                src_type["element_type"],
+                src_type.vector.element_type,
                 analyze_type_info(
                     dst_type_variant.elem_type if dst_type_variant else Any
                 ),
@@ -432,7 +441,7 @@ def _get_auto_default_for_type(
 
 def make_engine_struct_decoder(
     field_path: list[str],
-    src_fields: list[dict[str, Any]],
+    src_fields: list[FieldSchema],
     dst_type_info: AnalyzedTypeInfo,
     for_key: bool = False,
 ) -> Callable[[list[Any]], Any]:
@@ -461,7 +470,7 @@ def make_engine_struct_decoder(
             f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
         )
 
-    src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
+    src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
     dst_struct_type = dst_type_variant.struct_type
 
     parameters: Mapping[str, inspect.Parameter]
@@ -493,7 +502,10 @@ def make_closure_for_field(
         with ChildFieldPath(field_path, f".{name}"):
             if src_idx is not None:
                 field_decoder = make_engine_value_decoder(
-                    field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
+                    field_path,
+                    src_fields[src_idx].value_type.type,
+                    type_info,
+                    for_key=for_key,
                 )
                 return lambda values: field_decoder(values[src_idx])
 
@@ -526,7 +538,7 @@ def make_closure_for_field(
 
 def _make_engine_struct_to_dict_decoder(
     field_path: list[str],
-    src_fields: list[dict[str, Any]],
+    src_fields: list[FieldSchema],
     value_type_annotation: Any,
 ) -> Callable[[list[Any] | None], dict[str, Any] | None]:
     """Make a decoder from engine field values to a Python dict."""
@@ -534,11 +546,11 @@
     field_decoders = []
     value_type_info = analyze_type_info(value_type_annotation)
     for field_schema in src_fields:
-        field_name = field_schema["name"]
+        field_name = field_schema.name
         with ChildFieldPath(field_path, f".{field_name}"):
             field_decoder = make_engine_value_decoder(
                 field_path,
-                field_schema["type"],
+                field_schema.value_type.type,
                 value_type_info,
             )
             field_decoders.append((field_name, field_decoder))
@@ -560,19 +572,19 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
 
 def _make_engine_struct_to_tuple_decoder(
     field_path: list[str],
-    src_fields: list[dict[str, Any]],
+    src_fields: list[FieldSchema],
 ) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
     """Make a decoder from engine field values to a Python tuple."""
     field_decoders = []
     value_type_info = analyze_type_info(Any)
     for field_schema in src_fields:
-        field_name = field_schema["name"]
+        field_name = field_schema.name
         with ChildFieldPath(field_path, f".{field_name}"):
             field_decoders.append(
                 make_engine_value_decoder(
                     field_path,
-                    field_schema["type"],
+                    field_schema.value_type.type,
                     value_type_info,
                 )
             )
 
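
Not part of the patch: a minimal sketch of the new calling convention in convert.py. The decoder now receives a typed ValueType instead of a raw dict, so callers run the engine-provided JSON through decode_engine_value_type first (the same pattern the updated test helper below uses). The top-level import paths cocoindex.convert and cocoindex.typing are assumed here.

```python
# Illustrative only: build a decoder for an engine Vector value using the
# new typed schema objects rather than a raw {"kind": ...} dict.
import numpy as np

from cocoindex.convert import make_engine_value_decoder
from cocoindex.typing import analyze_type_info, decode_engine_value_type

src_type = decode_engine_value_type(
    {"kind": "Vector", "element_type": {"kind": "Float64"}, "dimension": None}
)
decoder = make_engine_value_decoder([], src_type, analyze_type_info(list[np.float64]))
assert decoder([1.0, 2.0, 3.0]) == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]
```
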
diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py
index 02c74795a..f60d3bf69 100644
--- a/python/cocoindex/flow.py
+++ b/python/cocoindex/flow.py
@@ -40,7 +40,7 @@
 from .op import FunctionSpec
 from .runtime import execution_context, to_async_call
 from .setup import SetupChangeBundle
-from .typing import analyze_type_info, encode_enriched_type
+from .typing import analyze_type_info, encode_enriched_type, decode_engine_value_type
 from .query_handler import QueryHandlerInfo, QueryHandlerResultFields
 from .validation import (
     validate_flow_name,
@@ -1164,7 +1164,9 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
             inspect.signature(self._flow_fn).return_annotation
         )
         result_decoder = make_engine_value_decoder(
-            [], engine_return_type["type"], analyze_type_info(python_return_type)
+            [],
+            decode_engine_value_type(engine_return_type["type"]),
+            analyze_type_info(python_return_type),
         )
 
         return TransformFlowInfo(engine_flow, result_decoder)
diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py
index 21b44d163..6ad30892d 100644
--- a/python/cocoindex/op.py
+++ b/python/cocoindex/op.py
@@ -30,6 +30,8 @@
     analyze_type_info,
     AnalyzedAnyType,
     AnalyzedDictType,
+    EnrichedValueType,
+    decode_engine_field_schemas,
 )
 
 from .runtime import to_async_call
@@ -212,8 +214,9 @@ def process_arg(
                 TypeAttr(related_attr.value, actual_arg.analyzed_value)
             )
         type_info = analyze_type_info(arg_param.annotation)
+        enriched = EnrichedValueType.decode(actual_arg.value_type)
         decoder = make_engine_value_decoder(
-            [arg_name], actual_arg.value_type["type"], type_info
+            [arg_name], enriched.type, type_info
         )
         is_required = not type_info.nullable
         if is_required and actual_arg.value_type.get("nullable", False):
@@ -527,10 +530,14 @@ def create_export_context(
         )
 
         key_decoder = make_engine_key_decoder(
-            ["(key)"], key_fields_schema, analyze_type_info(key_annotation)
+            ["(key)"],
+            decode_engine_field_schemas(key_fields_schema),
+            analyze_type_info(key_annotation),
         )
         value_decoder = make_engine_struct_decoder(
-            ["(value)"], value_fields_schema, analyze_type_info(value_annotation)
+            ["(value)"],
+            decode_engine_field_schemas(value_fields_schema),
+            analyze_type_info(value_annotation),
         )
 
         loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
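
Not part of the patch: a sketch of what process_arg now does with each argument's schema. The raw dict the engine hands over (a hypothetical stand-in for actual_arg.value_type) is decoded into an EnrichedValueType, and its .type field is what gets passed to make_engine_value_decoder.

```python
# Hypothetical argument schema in the engine's JSON form.
from cocoindex.typing import EnrichedValueType

raw_value_type = {"type": {"kind": "Str"}, "nullable": True}

enriched = EnrichedValueType.decode(raw_value_type)
assert enriched.type.kind == "Str"
assert enriched.nullable is True  # nullable flag decoded from the engine schema
```
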
diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py
index 902f18d02..a3364fe88 100644
--- a/python/cocoindex/tests/test_convert.py
+++ b/python/cocoindex/tests/test_convert.py
@@ -21,6 +21,7 @@
     Vector,
     analyze_type_info,
     encode_enriched_type,
+    decode_engine_value_type,
 )
 
 
@@ -86,7 +87,9 @@ def build_engine_value_decoder(
     """
     engine_type = encode_enriched_type(engine_type_in_py)["type"]
     return make_engine_value_decoder(
-        [], engine_type, analyze_type_info(python_type or engine_type_in_py)
+        [],
+        decode_engine_value_type(engine_type),
+        analyze_type_info(python_type or engine_type_in_py),
     )
 
 
@@ -116,7 +119,9 @@ def eq(a: Any, b: Any) -> bool:
 
     for other_value, other_type in decoded_values:
         decoder = make_engine_value_decoder(
-            [], encoded_output_type, analyze_type_info(other_type)
+            [],
+            decode_engine_value_type(encoded_output_type),
+            analyze_type_info(other_type),
         )
         other_decoded_value = decoder(value_from_engine)
         assert eq(other_decoded_value, other_value), (
@@ -383,9 +388,19 @@ def test_roundtrip_json() -> None:
 
 def test_decode_scalar_numpy_values() -> None:
     test_cases = [
-        ({"kind": "Int64"}, np.int64, 42, np.int64(42)),
-        ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
-        ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
+        (decode_engine_value_type({"kind": "Int64"}), np.int64, 42, np.int64(42)),
+        (
+            decode_engine_value_type({"kind": "Float32"}),
+            np.float32,
+            3.14,
+            np.float32(3.14),
+        ),
+        (
+            decode_engine_value_type({"kind": "Float64"}),
+            np.float64,
+            2.718,
+            np.float64(2.718),
+        ),
     ]
     for src_type, dst_type, input_value, expected in test_cases:
         decoder = make_engine_value_decoder(
@@ -398,11 +413,13 @@ def test_decode_scalar_numpy_values() -> None:
 
 def test_non_ndarray_vector_decoding() -> None:
     # Test list[np.float64]
-    src_type = {
-        "kind": "Vector",
-        "element_type": {"kind": "Float64"},
-        "dimension": None,
-    }
+    src_type = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Float64"},
+            "dimension": None,
+        }
+    )
     dst_type_float = list[np.float64]
     decoder = make_engine_value_decoder(
         ["field"], src_type, analyze_type_info(dst_type_float)
@@ -414,7 +431,9 @@ def test_non_ndarray_vector_decoding() -> None:
     assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]
 
     # Test list[Uuid]
-    src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
+    src_type = decode_engine_value_type(
+        {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
+    )
     dst_type_uuid = list[uuid.UUID]
     decoder = make_engine_value_decoder(
         ["field"], src_type, analyze_type_info(dst_type_uuid)
@@ -895,11 +914,13 @@ class MyStructWithNDArray:
 
 def test_decode_nullable_ndarray_none_or_value_input() -> None:
     """Test decoding a nullable NDArray with None or value inputs."""
-    src_type_dict = {
-        "kind": "Vector",
-        "element_type": {"kind": "Float32"},
-        "dimension": None,
-    }
+    src_type_dict = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Float32"},
+            "dimension": None,
+        }
+    )
     dst_annotation = NDArrayFloat32Type | None
     decoder = make_engine_value_decoder(
         [], src_type_dict, analyze_type_info(dst_annotation)
@@ -921,11 +942,13 @@ def test_decode_nullable_ndarray_none_or_value_input() -> None:
 
 def test_decode_vector_string() -> None:
     """Test decoding a vector of strings works for Python native list type."""
-    src_type_dict = {
-        "kind": "Vector",
-        "element_type": {"kind": "Str"},
-        "dimension": None,
-    }
+    src_type_dict = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Str"},
+            "dimension": None,
+        }
+    )
     decoder = make_engine_value_decoder(
         [], src_type_dict, analyze_type_info(Vector[str])
     )
@@ -934,11 +957,13 @@
 
 def test_decode_error_non_nullable_or_non_list_vector() -> None:
     """Test decoding errors for non-nullable vectors or non-list inputs."""
-    src_type_dict = {
-        "kind": "Vector",
-        "element_type": {"kind": "Float32"},
-        "dimension": None,
-    }
+    src_type_dict = decode_engine_value_type(
+        {
+            "kind": "Vector",
+            "element_type": {"kind": "Float32"},
+            "dimension": None,
+        }
+    )
     decoder = make_engine_value_decoder(
         [], src_type_dict, analyze_type_info(NDArrayFloat32Type)
     )
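
Not part of the patch: a sketch of decode_engine_field_schemas, which create_export_context above uses to turn the raw key/value field lists into FieldSchema objects. The field dicts are hypothetical examples of the engine's serialized form.

```python
# Hypothetical field schemas as the engine would serialize them.
from cocoindex.typing import decode_engine_field_schemas

fields = decode_engine_field_schemas(
    [
        {"name": "id", "type": {"kind": "Int64"}},
        {"name": "title", "type": {"kind": "Str"}, "nullable": True},
    ]
)
assert fields[0].name == "id"
assert fields[0].value_type.type.kind == "Int64"
assert fields[1].value_type.nullable is True
```
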
diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py
index e14b634bc..a74b15291 100644
--- a/python/cocoindex/typing.py
+++ b/python/cocoindex/typing.py
@@ -15,6 +15,7 @@
     Protocol,
     TypeVar,
     overload,
+    Self,
 )
 
 import numpy as np
@@ -471,3 +472,157 @@ def resolve_forward_ref(t: Any) -> Any:
     if isinstance(t, str):
         return eval(t)  # pylint: disable=eval-used
     return t
+
+
+# ========================= Engine Schema Types (Python mirror of Rust) =========================
+
+
+@dataclasses.dataclass
+class VectorTypeSchema:
+    element_type: "BasicValueType"
+    dimension: int | None
+
+    @staticmethod
+    def decode(obj: dict[str, Any]) -> "VectorTypeSchema":
+        return VectorTypeSchema(
+            element_type=BasicValueType.decode(obj["element_type"]),
+            dimension=obj.get("dimension"),
+        )
+
+
+@dataclasses.dataclass
+class UnionTypeSchema:
+    variants: list["BasicValueType"]
+
+    @staticmethod
+    def decode(obj: dict[str, Any]) -> "UnionTypeSchema":
+        return UnionTypeSchema(
+            variants=[BasicValueType.decode(t) for t in obj["types"]]
+        )
+
+
+@dataclasses.dataclass
+class BasicValueType:
+    """
+    Mirror of Rust BasicValueType in JSON form.
+
+    For Vector and Union kinds, extra fields are populated accordingly.
+    """
+
+    kind: Literal[
+        "Bytes",
+        "Str",
+        "Bool",
+        "Int64",
+        "Float32",
+        "Float64",
+        "Range",
+        "Uuid",
+        "Date",
+        "Time",
+        "LocalDateTime",
+        "OffsetDateTime",
+        "TimeDelta",
+        "Json",
+        "Vector",
+        "Union",
+    ]
+    vector: VectorTypeSchema | None = None
+    union: UnionTypeSchema | None = None
+
+    @staticmethod
+    def decode(obj: dict[str, Any]) -> "BasicValueType":
+        kind = obj["kind"]
+        if kind == "Vector":
+            return BasicValueType(
+                kind=kind,  # type: ignore[arg-type]
+                vector=VectorTypeSchema.decode(obj),
+            )
+        if kind == "Union":
+            return BasicValueType(
+                kind=kind,  # type: ignore[arg-type]
+                union=UnionTypeSchema.decode(obj),
+            )
+        return BasicValueType(kind=kind)  # type: ignore[arg-type]
+
+
+@dataclasses.dataclass
+class EnrichedValueType:
+    type: "ValueType"
+    nullable: bool = False
+    attrs: dict[str, Any] | None = None
+
+    @staticmethod
+    def decode(obj: dict[str, Any]) -> "EnrichedValueType":
+        return EnrichedValueType(
+            type=decode_engine_value_type(obj["type"]),
+            nullable=obj.get("nullable", False),
+            attrs=obj.get("attrs"),
+        )
+
+
+@dataclasses.dataclass
+class FieldSchema:
+    name: str
+    value_type: EnrichedValueType
+
+    @staticmethod
+    def decode(obj: dict[str, Any]) -> "FieldSchema":
+        return FieldSchema(name=obj["name"], value_type=EnrichedValueType.decode(obj))
+
+
+@dataclasses.dataclass
+class StructSchema:
+    fields: list[FieldSchema]
+    description: str | None = None
+
+    @classmethod
+    def decode(cls, obj: dict[str, Any]) -> Self:
+        return cls(
+            fields=[FieldSchema.decode(f) for f in obj["fields"]],
+            description=obj.get("description"),
+        )
+
+
+@dataclasses.dataclass
+class StructType(StructSchema):
+    kind: Literal["Struct"] = "Struct"
+
+
+@dataclasses.dataclass
+class TableType:
+    kind: Literal["KTable", "LTable"]
+    row: StructSchema
+    num_key_parts: int | None = None  # Only for KTable
+
+    @staticmethod
+    def decode(obj: dict[str, Any]) -> "TableType":
+        row_obj = obj["row"]
+        row = StructSchema(
+            fields=[FieldSchema.decode(f) for f in row_obj["fields"]],
+            description=row_obj.get("description"),
+        )
+        return TableType(
+            kind=obj["kind"],  # type: ignore[arg-type]
+            row=row,
+            num_key_parts=obj.get("num_key_parts"),
+        )
+
+
+ValueType = BasicValueType | StructType | TableType
+
+
+def decode_engine_field_schemas(objs: list[dict[str, Any]]) -> list[FieldSchema]:
+    return [FieldSchema.decode(o) for o in objs]
+
+
+def decode_engine_value_type(obj: dict[str, Any]) -> ValueType:
+    kind = obj["kind"]
+    if kind == "Struct":
+        return StructType.decode(obj)
+
+    if kind in TABLE_TYPES:
+        return TableType.decode(obj)
+
+    # Otherwise it's a basic value
+    return BasicValueType.decode(obj)
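
Not part of the patch: a sketch of how the new mirror types fit together. decode_engine_value_type dispatches on "kind" and returns a StructType, TableType, or BasicValueType; the KTable JSON below is a hypothetical example of what the engine emits.

```python
# Hypothetical KTable schema in the engine's JSON form.
from cocoindex.typing import TableType, decode_engine_value_type

table = decode_engine_value_type(
    {
        "kind": "KTable",
        "row": {
            "fields": [
                {"name": "id", "type": {"kind": "Int64"}},
                {"name": "text", "type": {"kind": "Str"}},
            ]
        },
        "num_key_parts": 1,
    }
)
assert isinstance(table, TableType)
assert table.kind == "KTable"
assert table.row.fields[0].name == "id"
assert table.num_key_parts == 1
```
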