64 changes: 38 additions & 26 deletions python/cocoindex/convert.py
@@ -9,12 +9,11 @@
import inspect
import warnings
from enum import Enum
from typing import Any, Callable, Mapping, Sequence, Type, get_origin
from typing import Any, Callable, Mapping, get_origin

import numpy as np

from .typing import (
TABLE_TYPES,
AnalyzedAnyType,
AnalyzedBasicType,
AnalyzedDictType,
@@ -27,6 +26,11 @@
encode_enriched_type,
is_namedtuple_type,
is_numpy_number_type,
ValueType,
FieldSchema,
BasicValueType,
StructType,
TableType,
)


@@ -172,7 +176,7 @@ def encode_basic_value(value: Any) -> Any:

def make_engine_key_decoder(
field_path: list[str],
key_fields_schema: list[dict[str, Any]],
key_fields_schema: list[FieldSchema],
dst_type_info: AnalyzedTypeInfo,
) -> Callable[[Any], Any]:
"""
@@ -183,7 +187,7 @@
):
single_key_decoder = make_engine_value_decoder(
field_path,
key_fields_schema[0]["type"],
key_fields_schema[0].value_type.type,
dst_type_info,
for_key=True,
)
@@ -203,7 +207,7 @@ def key_decoder(value: list[Any]) -> Any:

def make_engine_value_decoder(
field_path: list[str],
src_type: dict[str, Any],
src_type: ValueType,
dst_type_info: AnalyzedTypeInfo,
for_key: bool = False,
) -> Callable[[Any], Any]:
@@ -219,7 +223,7 @@ def make_engine_value_decoder(
A decoder from an engine value to a Python value.
"""

src_type_kind = src_type["kind"]
src_type_kind = src_type.kind

dst_type_variant = dst_type_info.variant

@@ -229,19 +233,19 @@
f"declared `{dst_type_info.core_type}`, an unsupported type"
)

if src_type_kind == "Struct":
if isinstance(src_type, StructType): # type: ignore[redundant-cast]
return make_engine_struct_decoder(
field_path,
src_type["fields"],
src_type.fields,
dst_type_info,
for_key=for_key,
)

if src_type_kind in TABLE_TYPES:
if isinstance(src_type, TableType): # type: ignore[redundant-cast]
with ChildFieldPath(field_path, "[*]"):
engine_fields_schema = src_type["row"]["fields"]
engine_fields_schema = src_type.row.fields

if src_type_kind == "LTable":
if src_type.kind == "LTable":
if isinstance(dst_type_variant, AnalyzedAnyType):
dst_elem_type = Any
elif isinstance(dst_type_variant, AnalyzedListType):
@@ -262,7 +266,7 @@ def decode(value: Any) -> Any | None:
return None
return [row_decoder(v) for v in value]

elif src_type_kind == "KTable":
elif src_type.kind == "KTable":
if isinstance(dst_type_variant, AnalyzedAnyType):
key_type, value_type = Any, Any
elif isinstance(dst_type_variant, AnalyzedDictType):
@@ -274,7 +278,7 @@ def decode(value: Any) -> Any | None:
f"declared `{dst_type_info.core_type}`, a dict type expected"
)

num_key_parts = src_type.get("num_key_parts", 1)
num_key_parts = src_type.num_key_parts or 1
key_decoder = make_engine_key_decoder(
field_path,
engine_fields_schema[0:num_key_parts],
@@ -298,7 +302,7 @@ def decode(value: Any) -> Any | None:

return decode

if src_type_kind == "Union":
if isinstance(src_type, BasicValueType) and src_type.kind == "Union":
if isinstance(dst_type_variant, AnalyzedAnyType):
return lambda value: value[1]

@@ -307,7 +311,10 @@ def decode(value: Any) -> Any | None:
if isinstance(dst_type_variant, AnalyzedUnionType)
else [dst_type_info]
)
src_type_variants = src_type["types"]
# mypy: union info exists for Union kind
assert src_type.union is not None # type: ignore[unreachable]
src_type_variants_basic: list[BasicValueType] = src_type.union.variants
src_type_variants = src_type_variants_basic
decoders = []
for i, src_type_variant in enumerate(src_type_variants):
with ChildFieldPath(field_path, f"[{i}]"):
@@ -331,7 +338,7 @@ def decode(value: Any) -> Any | None:
if isinstance(dst_type_variant, AnalyzedAnyType):
return lambda value: value

if src_type_kind == "Vector":
if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
field_path_str = "".join(field_path)
if not isinstance(dst_type_variant, AnalyzedListType):
raise ValueError(
@@ -350,9 +357,11 @@ def decode(value: Any) -> Any | None:
if is_numpy_number_type(dst_type_variant.elem_type):
scalar_dtype = dst_type_variant.elem_type
else:
# mypy: vector info exists for Vector kind
assert src_type.vector is not None # type: ignore[unreachable]
vec_elem_decoder = make_engine_value_decoder(
field_path + ["[*]"],
src_type["element_type"],
src_type.vector.element_type,
analyze_type_info(
dst_type_variant.elem_type if dst_type_variant else Any
),
@@ -432,7 +441,7 @@ def _get_auto_default_for_type(

def make_engine_struct_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
src_fields: list[FieldSchema],
dst_type_info: AnalyzedTypeInfo,
for_key: bool = False,
) -> Callable[[list[Any]], Any]:
Expand Down Expand Up @@ -461,7 +470,7 @@ def make_engine_struct_decoder(
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
)

src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)}
src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
dst_struct_type = dst_type_variant.struct_type

parameters: Mapping[str, inspect.Parameter]
@@ -493,7 +502,10 @@ def make_closure_for_field(
with ChildFieldPath(field_path, f".{name}"):
if src_idx is not None:
field_decoder = make_engine_value_decoder(
field_path, src_fields[src_idx]["type"], type_info, for_key=for_key
field_path,
src_fields[src_idx].value_type.type,
type_info,
for_key=for_key,
)
return lambda values: field_decoder(values[src_idx])

@@ -526,19 +538,19 @@ def make_closure_for_field(

def _make_engine_struct_to_dict_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
src_fields: list[FieldSchema],
value_type_annotation: Any,
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
"""Make a decoder from engine field values to a Python dict."""

field_decoders = []
value_type_info = analyze_type_info(value_type_annotation)
for field_schema in src_fields:
field_name = field_schema["name"]
field_name = field_schema.name
with ChildFieldPath(field_path, f".{field_name}"):
field_decoder = make_engine_value_decoder(
field_path,
field_schema["type"],
field_schema.value_type.type,
value_type_info,
)
field_decoders.append((field_name, field_decoder))
@@ -560,19 +572,19 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:

def _make_engine_struct_to_tuple_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
src_fields: list[FieldSchema],
) -> Callable[[list[Any] | None], tuple[Any, ...] | None]:
"""Make a decoder from engine field values to a Python tuple."""

field_decoders = []
value_type_info = analyze_type_info(Any)
for field_schema in src_fields:
field_name = field_schema["name"]
field_name = field_schema.name
with ChildFieldPath(field_path, f".{field_name}"):
field_decoders.append(
make_engine_value_decoder(
field_path,
field_schema["type"],
field_schema.value_type.type,
value_type_info,
)
)
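Taken together, the convert.py changes replace dict lookups on engine schemas (src_type["kind"], f["name"], f["type"]) with attribute access on typed schema objects. A minimal sketch of the resulting access pattern, assuming ValueType, StructType, and TableType can be imported from cocoindex.typing the same way convert.py imports them; the helper itself is illustrative and not part of this PR:

```python
from cocoindex.typing import StructType, TableType, ValueType


def describe_schema(src_type: ValueType) -> str:
    """Illustrative only: mirrors the dispatch order the decoder now uses."""
    if isinstance(src_type, StructType):
        # Struct fields are FieldSchema objects rather than dicts: f.name, f.value_type.type
        return "Struct(" + ", ".join(f.name for f in src_type.fields) + ")"
    if isinstance(src_type, TableType):
        # LTable/KTable both expose their row schema as src_type.row.fields
        return f"{src_type.kind} with {len(src_type.row.fields)} row fields"
    # Remaining kinds are BasicValueType; .kind replaces the old src_type["kind"] lookup
    return src_type.kind
```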
6 changes: 4 additions & 2 deletions python/cocoindex/flow.py
@@ -40,7 +40,7 @@
from .op import FunctionSpec
from .runtime import execution_context, to_async_call
from .setup import SetupChangeBundle
from .typing import analyze_type_info, encode_enriched_type
from .typing import analyze_type_info, encode_enriched_type, decode_engine_value_type
from .query_handler import QueryHandlerInfo, QueryHandlerResultFields
from .validation import (
validate_flow_name,
@@ -1164,7 +1164,9 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
inspect.signature(self._flow_fn).return_annotation
)
result_decoder = make_engine_value_decoder(
[], engine_return_type["type"], analyze_type_info(python_return_type)
[],
decode_engine_value_type(engine_return_type["type"]),
analyze_type_info(python_return_type),
)

return TransformFlowInfo(engine_flow, result_decoder)
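The flow.py side still receives a raw dict from the engine and now converts it with decode_engine_value_type before calling make_engine_value_decoder, which expects a typed ValueType. A hedged sketch of that boundary, using only the names visible in this diff:

```python
from typing import Any, Callable

from cocoindex.convert import make_engine_value_decoder
from cocoindex.typing import analyze_type_info, decode_engine_value_type


def build_result_decoder(
    engine_return_type: dict[str, Any], python_return_type: Any
) -> Callable[[Any], Any]:
    # The engine still reports {"type": {...}}; decode it into a typed ValueType first.
    typed_return = decode_engine_value_type(engine_return_type["type"])
    return make_engine_value_decoder(
        [], typed_return, analyze_type_info(python_return_type)
    )
```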
13 changes: 10 additions & 3 deletions python/cocoindex/op.py
@@ -30,6 +30,8 @@
analyze_type_info,
AnalyzedAnyType,
AnalyzedDictType,
EnrichedValueType,
decode_engine_field_schemas,
)
from .runtime import to_async_call

@@ -212,8 +214,9 @@ def process_arg(
TypeAttr(related_attr.value, actual_arg.analyzed_value)
)
type_info = analyze_type_info(arg_param.annotation)
enriched = EnrichedValueType.decode(actual_arg.value_type)
decoder = make_engine_value_decoder(
[arg_name], actual_arg.value_type["type"], type_info
[arg_name], enriched.type, type_info
)
is_required = not type_info.nullable
if is_required and actual_arg.value_type.get("nullable", False):
@@ -527,10 +530,14 @@ def create_export_context(
)

key_decoder = make_engine_key_decoder(
["(key)"], key_fields_schema, analyze_type_info(key_annotation)
["(key)"],
decode_engine_field_schemas(key_fields_schema),
analyze_type_info(key_annotation),
)
value_decoder = make_engine_struct_decoder(
["(value)"], value_fields_schema, analyze_type_info(value_annotation)
["(value)"],
decode_engine_field_schemas(value_fields_schema),
analyze_type_info(value_annotation),
)

loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
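op.py applies the same conversion at the operator boundary: raw engine field schemas are decoded once via decode_engine_field_schemas (and single argument types via EnrichedValueType.decode) before reaching the decoders. A sketch under the assumption that decode_engine_field_schemas and analyze_type_info live in cocoindex.typing, as the other imports in this PR suggest; build_export_decoders is a hypothetical wrapper, not code from the PR:

```python
from typing import Any, Callable

from cocoindex.convert import make_engine_key_decoder, make_engine_struct_decoder
from cocoindex.typing import analyze_type_info, decode_engine_field_schemas


def build_export_decoders(
    key_fields_schema: list[dict[str, Any]],
    value_fields_schema: list[dict[str, Any]],
    key_annotation: Any,
    value_annotation: Any,
) -> tuple[Callable[[Any], Any], Callable[[list[Any]], Any]]:
    # Raw engine field schemas (lists of dicts) are decoded once into FieldSchema objects,
    # so the per-value decode path only ever sees typed schemas.
    key_decoder = make_engine_key_decoder(
        ["(key)"],
        decode_engine_field_schemas(key_fields_schema),
        analyze_type_info(key_annotation),
    )
    value_decoder = make_engine_struct_decoder(
        ["(value)"],
        decode_engine_field_schemas(value_fields_schema),
        analyze_type_info(value_annotation),
    )
    return key_decoder, value_decoder
```

One plausible reading of the overall change: decoding schemas into typed objects at the engine boundary keeps dict lookups out of the hot decode path and lets static type checkers verify the attribute access inside convert.py.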