133 changes: 131 additions & 2 deletions python/cocoindex/convert.py
@@ -9,7 +9,7 @@
 import inspect
 import warnings
 from enum import Enum
-from typing import Any, Callable, Mapping, get_origin
+from typing import Any, Callable, Mapping, get_origin, TypeVar, overload

 import numpy as np

@@ -26,6 +26,7 @@
     encode_enriched_type,
     is_namedtuple_type,
     is_numpy_number_type,
+    extract_ndarray_elem_dtype,
     ValueType,
     FieldSchema,
     BasicValueType,
@@ -34,6 +35,9 @@
 )


+T = TypeVar("T")
+
+
 class ChildFieldPath:
     """Context manager to append a field to field_path on enter and pop it on exit."""

@@ -616,7 +620,7 @@ def dump_engine_object(v: Any) -> Any:
         secs = int(total_secs)
         nanos = int((total_secs - secs) * 1e9)
         return {"secs": secs, "nanos": nanos}
-    elif hasattr(v, "__dict__"):
+    elif hasattr(v, "__dict__"):  # for dataclass-like objects
         s = {}
         for k, val in v.__dict__.items():
             if val is None:
@@ -633,3 +637,128 @@ def dump_engine_object(v: Any) -> Any:
     elif isinstance(v, dict):
         return {k: dump_engine_object(v) for k, v in v.items()}
     return v
+
+
+@overload
+def load_engine_object(expected_type: type[T], v: Any) -> T: ...
+@overload
+def load_engine_object(expected_type: Any, v: Any) -> Any: ...
+def load_engine_object(expected_type: Any, v: Any) -> Any:
+    """Recursively load an object that was produced by dump_engine_object().
+
+    Args:
+        expected_type: The Python type annotation to reconstruct to.
+        v: The engine-facing Pythonized object (e.g., dict/list/primitive) to convert.
+
+    Returns:
+        A Python object matching the expected_type where possible.
+    """
+    # Fast path
+    if v is None:
+        return None
+
+    type_info = analyze_type_info(expected_type)
+    variant = type_info.variant
+
+    # Any or unknown → return as-is
+    if isinstance(variant, AnalyzedAnyType) or type_info.base_type is Any:
+        return v
+
+    # Enum handling
+    if isinstance(expected_type, type) and issubclass(expected_type, Enum):
+        return expected_type(v)
+
+    # TimeDelta special form {secs, nanos}
+    if isinstance(variant, AnalyzedBasicType) and variant.kind == "TimeDelta":
+        if isinstance(v, Mapping) and "secs" in v and "nanos" in v:
+            secs = int(v["secs"])  # type: ignore[index]
+            nanos = int(v["nanos"])  # type: ignore[index]
+            return datetime.timedelta(seconds=secs, microseconds=nanos / 1_000)
+        return v
+
+    # List, NDArray (Vector-ish), or general sequences
+    if isinstance(variant, AnalyzedListType):
+        elem_type = variant.elem_type if variant.elem_type else Any
+        if type_info.base_type is np.ndarray:
+            # Reconstruct NDArray with appropriate dtype if available
+            try:
+                dtype = extract_ndarray_elem_dtype(type_info.core_type)
+            except (TypeError, ValueError, AttributeError):
+                dtype = None
+            return np.array(v, dtype=dtype)
+        # Regular Python list
+        return [load_engine_object(elem_type, item) for item in v]
+
+    # Dict / Mapping
+    if isinstance(variant, AnalyzedDictType):
+        key_t = variant.key_type
+        val_t = variant.value_type
+        return {
+            load_engine_object(key_t, k): load_engine_object(val_t, val)
+            for k, val in v.items()
+        }
+
+    # Structs (dataclass or NamedTuple)
+    if isinstance(variant, AnalyzedStructType):
+        struct_type = variant.struct_type
+        if dataclasses.is_dataclass(struct_type):
+            # Drop auxiliary discriminator "kind" if present
+            src = dict(v) if isinstance(v, Mapping) else v
+            if isinstance(src, Mapping):
+                init_kwargs: dict[str, Any] = {}
+                field_types = {f.name: f.type for f in dataclasses.fields(struct_type)}
+                for name, f_type in field_types.items():
+                    if name in src:
+                        init_kwargs[name] = load_engine_object(f_type, src[name])
+                # Construct with defaults for missing fields
+                return struct_type(**init_kwargs)
+        elif is_namedtuple_type(struct_type):
+            # NamedTuple is dumped as list/tuple of items
+            annotations = getattr(struct_type, "__annotations__", {})
+            field_names = list(getattr(struct_type, "_fields", ()))
+            values: list[Any] = []
+            for name in field_names:
+                f_type = annotations.get(name, Any)
+                # Assume v is a sequence aligned with fields
+                if isinstance(v, (list, tuple)):
+                    idx = field_names.index(name)
+                    values.append(load_engine_object(f_type, v[idx]))
+                elif isinstance(v, Mapping):
+                    values.append(load_engine_object(f_type, v.get(name)))
+                else:
+                    values.append(v)
+            return struct_type(*values)
+        return v
+
+    # Union with discriminator support via "kind"
+    if isinstance(variant, AnalyzedUnionType):
+        if isinstance(v, Mapping) and "kind" in v:
+            discriminator = v["kind"]
+            for typ in variant.variant_types:
+                t_info = analyze_type_info(typ)
+                if isinstance(t_info.variant, AnalyzedStructType):
+                    t_struct = t_info.variant.struct_type
+                    candidate_kind = getattr(t_struct, "kind", None)
+                    if candidate_kind == discriminator:
+                        # Remove discriminator for constructor
+                        v_wo_kind = dict(v)
+                        v_wo_kind.pop("kind", None)
+                        return load_engine_object(t_struct, v_wo_kind)
+        # Fallback: try each variant until one succeeds
+        for typ in variant.variant_types:
+            try:
+                return load_engine_object(typ, v)
+            except (TypeError, ValueError):
+                continue
+        return v
+
+    # Basic types and everything else: handle numpy scalars and passthrough
+    if isinstance(v, np.ndarray) and type_info.base_type is list:
+        return v.tolist()
+    if isinstance(v, (list, tuple)) and type_info.base_type not in (list, tuple):
+        # If a non-sequence basic type expected, attempt direct cast
+        try:
+            return type_info.core_type(v)
+        except (TypeError, ValueError):
+            return v
+    return v
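
As a quick orientation to the new loader, here is a minimal round-trip sketch. `MySpec` and its fields are hypothetical, and the sketch assumes `analyze_type_info` classifies `datetime.timedelta` as the `TimeDelta` kind; only `dump_engine_object` and `load_engine_object` come from this diff.

```python
# Hedged sketch: MySpec is invented for illustration.
import dataclasses
import datetime

from cocoindex.convert import dump_engine_object, load_engine_object


@dataclasses.dataclass
class MySpec:
    name: str
    timeout: datetime.timedelta
    retries: int = 3


spec = MySpec(name="demo", timeout=datetime.timedelta(seconds=1.5))
dumped = dump_engine_object(spec)
# timedelta becomes {"secs": 1, "nanos": 500000000}; other fields stay plain.
loaded = load_engine_object(MySpec, dumped)
assert loaded == spec
```

The `"kind"`-discriminated union path can be exercised the same way, assuming `analyze_type_info` reports `X | Y` as a union of struct types; both variant classes below are invented for illustration.

```python
# Hedged sketch: each variant carries a class-level "kind" attribute that
# load_engine_object matches against the incoming dict's "kind" key.
import dataclasses

from cocoindex.convert import load_engine_object


@dataclasses.dataclass
class LocalFile:
    kind = "LocalFile"  # plain class attribute, not a dataclass field
    path: str


@dataclasses.dataclass
class S3Object:
    kind = "S3Object"
    bucket: str
    key: str


raw = {"kind": "S3Object", "bucket": "b", "key": "k"}
source = load_engine_object(LocalFile | S3Object, raw)
assert isinstance(source, S3Object) and source.bucket == "b"
```
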
44 changes: 18 additions & 26 deletions python/cocoindex/op.py
@@ -19,6 +19,7 @@
 from .subprocess_exec import executor_stub
 from .convert import (
     dump_engine_object,
+    load_engine_object,
     make_engine_value_encoder,
     make_engine_value_decoder,
     make_engine_key_decoder,
@@ -90,15 +91,6 @@ class Executor(Protocol):
     op_category: OpCategory


-def _load_spec_from_engine(
-    spec_loader: Callable[..., Any], spec: dict[str, Any]
-) -> Any:
-    """
-    Load a spec from the engine.
-    """
-    return spec_loader(**spec)
-
-
 def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
     method = getattr(cls, name, None)
     if method is None:
@@ -109,17 +101,17 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:


 class _EngineFunctionExecutorFactory:
-    _spec_loader: Callable[..., Any]
+    _spec_loader: Callable[[Any], Any]
     _executor_cls: type

     def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
         self._spec_loader = spec_loader
         self._executor_cls = executor_cls

     def __call__(
-        self, spec: dict[str, Any], *args: Any, **kwargs: Any
+        self, raw_spec: dict[str, Any], *args: Any, **kwargs: Any
     ) -> tuple[dict[str, Any], Executor]:
-        spec = _load_spec_from_engine(self._spec_loader, spec)
+        spec = self._spec_loader(raw_spec)
         executor = self._executor_cls(spec)
         result_type = executor.analyze_schema(*args, **kwargs)
         return (result_type, executor)
@@ -378,7 +370,7 @@ def _inner(cls: type[Executor]) -> type:
            expected_args=list(sig.parameters.items())[1:],  # First argument is `self`
            expected_return=sig.return_annotation,
            executor_factory=cls,
-           spec_loader=spec_cls,
+           spec_loader=lambda v: load_engine_object(spec_cls, v),
            op_kind=spec_cls.__name__,
            op_args=op_args,
        )
@@ -414,7 +406,7 @@ def _inner(fn: Callable[..., Any]) -> Callable[..., Any]:
            expected_args=list(sig.parameters.items()),
            expected_return=sig.return_annotation,
            executor_factory=_SimpleFunctionExecutor,
-           spec_loader=lambda: fn,
+           spec_loader=lambda _: fn,
            op_kind=op_kind,
            op_args=op_args,
        )
@@ -469,9 +461,9 @@ class _TargetConnector:
     The connector class passed to the engine.
     """

-    _spec_cls: type
-    _state_cls: type
-    _connector_cls: type
+    _spec_cls: type[Any]
+    _state_cls: type[Any]
+    _connector_cls: type[Any]

     _get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
     _apply_setup_change_async_fn: Callable[
@@ -480,7 +472,9 @@ class _TargetConnector:
     _mutate_async_fn: Callable[..., Awaitable[None]]
     _mutatation_type: AnalyzedDictType | None

-    def __init__(self, spec_cls: type, state_cls: type, connector_cls: type):
+    def __init__(
+        self, spec_cls: type[Any], state_cls: type[Any], connector_cls: type[Any]
+    ):
         self._spec_cls = spec_cls
         self._state_cls = state_cls
         self._connector_cls = connector_cls
@@ -546,7 +540,7 @@ def _analyze_mutate_mutation_type(
     def create_export_context(
         self,
         name: str,
-        spec: dict[str, Any],
+        raw_spec: dict[str, Any],
         raw_key_fields_schema: list[Any],
         raw_value_fields_schema: list[Any],
     ) -> _TargetConnectorContext:
@@ -568,10 +562,10 @@ def create_export_context(
             ["<value>"], value_fields_schema, analyze_type_info(value_annotation)
         )

-        loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
+        spec = load_engine_object(self._spec_cls, raw_spec)
         return _TargetConnectorContext(
             target_name=name,
-            spec=loaded_spec,
+            spec=spec,
             prepared_spec=None,
             key_fields_schema=key_fields_schema,
             key_decoder=key_decoder,
@@ -638,13 +632,11 @@ async def apply_setup_changes_async(
     ) -> None:
         for key, previous, current in changes:
             prev_specs = [
-                _load_spec_from_engine(self._state_cls, spec)
-                if spec is not None
-                else None
+                load_engine_object(self._state_cls, spec) if spec is not None else None
                 for spec in previous
             ]
             curr_spec = (
-                _load_spec_from_engine(self._state_cls, current)
+                load_engine_object(self._state_cls, current)
                 if current is not None
                 else None
             )
@@ -678,7 +670,7 @@ async def mutate_async(


 def target_connector(
-    spec_cls: type, state_cls: type | None = None
+    spec_cls: type[Any], state_cls: type[Any] | None = None
 ) -> Callable[[type], type]:
     """
     Decorate a class to provide a target connector for an op.
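
For context on the `spec_loader` change above: the removed `_load_spec_from_engine` was effectively `spec_cls(**spec)`, which left nested fields as the raw dicts the engine sent, while `load_engine_object` reconstructs them recursively. A hedged sketch, with both spec classes invented for illustration:

```python
import dataclasses
import datetime

from cocoindex.convert import load_engine_object


@dataclasses.dataclass
class ChunkSpec:
    size: int
    overlap: int


@dataclasses.dataclass
class SplitSpec:
    chunk: ChunkSpec
    timeout: datetime.timedelta


# The shape of dict the engine hands to the executor factory / create_export_context:
raw = {"chunk": {"size": 512, "overlap": 64}, "timeout": {"secs": 30, "nanos": 0}}
spec = load_engine_object(SplitSpec, raw)
assert isinstance(spec.chunk, ChunkSpec)  # a typed object, not a dict
assert spec.timeout == datetime.timedelta(seconds=30)
```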