Skip to content

Commit 854f0cd

Browse files
authored
feat(py-sdk-spec): add the full fledged load_engine_object() (#1022)
1 parent 5c73acc commit 854f0cd

File tree

3 files changed

+267
-28
lines changed

3 files changed

+267
-28
lines changed

python/cocoindex/convert.py

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import inspect
1010
import warnings
1111
from enum import Enum
12-
from typing import Any, Callable, Mapping, get_origin
12+
from typing import Any, Callable, Mapping, get_origin, TypeVar, overload
1313

1414
import numpy as np
1515

@@ -26,6 +26,7 @@
2626
encode_enriched_type,
2727
is_namedtuple_type,
2828
is_numpy_number_type,
29+
extract_ndarray_elem_dtype,
2930
ValueType,
3031
FieldSchema,
3132
BasicValueType,
@@ -34,6 +35,9 @@
3435
)
3536

3637

38+
T = TypeVar("T")
39+
40+
3741
class ChildFieldPath:
3842
"""Context manager to append a field to field_path on enter and pop it on exit."""
3943

@@ -616,7 +620,7 @@ def dump_engine_object(v: Any) -> Any:
616620
secs = int(total_secs)
617621
nanos = int((total_secs - secs) * 1e9)
618622
return {"secs": secs, "nanos": nanos}
619-
elif hasattr(v, "__dict__"):
623+
elif hasattr(v, "__dict__"): # for dataclass-like objects
620624
s = {}
621625
for k, val in v.__dict__.items():
622626
if val is None:
@@ -633,3 +637,128 @@ def dump_engine_object(v: Any) -> Any:
633637
elif isinstance(v, dict):
634638
return {k: dump_engine_object(v) for k, v in v.items()}
635639
return v
640+
641+
642+
@overload
643+
def load_engine_object(expected_type: type[T], v: Any) -> T: ...
644+
@overload
645+
def load_engine_object(expected_type: Any, v: Any) -> Any: ...
646+
def load_engine_object(expected_type: Any, v: Any) -> Any:
647+
"""Recursively load an object that was produced by dump_engine_object().
648+
649+
Args:
650+
expected_type: The Python type annotation to reconstruct to.
651+
v: The engine-facing Pythonized object (e.g., dict/list/primitive) to convert.
652+
653+
Returns:
654+
A Python object matching the expected_type where possible.
655+
"""
656+
# Fast path
657+
if v is None:
658+
return None
659+
660+
type_info = analyze_type_info(expected_type)
661+
variant = type_info.variant
662+
663+
# Any or unknown → return as-is
664+
if isinstance(variant, AnalyzedAnyType) or type_info.base_type is Any:
665+
return v
666+
667+
# Enum handling
668+
if isinstance(expected_type, type) and issubclass(expected_type, Enum):
669+
return expected_type(v)
670+
671+
# TimeDelta special form {secs, nanos}
672+
if isinstance(variant, AnalyzedBasicType) and variant.kind == "TimeDelta":
673+
if isinstance(v, Mapping) and "secs" in v and "nanos" in v:
674+
secs = int(v["secs"]) # type: ignore[index]
675+
nanos = int(v["nanos"]) # type: ignore[index]
676+
return datetime.timedelta(seconds=secs, microseconds=nanos / 1_000)
677+
return v
678+
679+
# List, NDArray (Vector-ish), or general sequences
680+
if isinstance(variant, AnalyzedListType):
681+
elem_type = variant.elem_type if variant.elem_type else Any
682+
if type_info.base_type is np.ndarray:
683+
# Reconstruct NDArray with appropriate dtype if available
684+
try:
685+
dtype = extract_ndarray_elem_dtype(type_info.core_type)
686+
except (TypeError, ValueError, AttributeError):
687+
dtype = None
688+
return np.array(v, dtype=dtype)
689+
# Regular Python list
690+
return [load_engine_object(elem_type, item) for item in v]
691+
692+
# Dict / Mapping
693+
if isinstance(variant, AnalyzedDictType):
694+
key_t = variant.key_type
695+
val_t = variant.value_type
696+
return {
697+
load_engine_object(key_t, k): load_engine_object(val_t, val)
698+
for k, val in v.items()
699+
}
700+
701+
# Structs (dataclass or NamedTuple)
702+
if isinstance(variant, AnalyzedStructType):
703+
struct_type = variant.struct_type
704+
if dataclasses.is_dataclass(struct_type):
705+
# Drop auxiliary discriminator "kind" if present
706+
src = dict(v) if isinstance(v, Mapping) else v
707+
if isinstance(src, Mapping):
708+
init_kwargs: dict[str, Any] = {}
709+
field_types = {f.name: f.type for f in dataclasses.fields(struct_type)}
710+
for name, f_type in field_types.items():
711+
if name in src:
712+
init_kwargs[name] = load_engine_object(f_type, src[name])
713+
# Construct with defaults for missing fields
714+
return struct_type(**init_kwargs)
715+
elif is_namedtuple_type(struct_type):
716+
# NamedTuple is dumped as list/tuple of items
717+
annotations = getattr(struct_type, "__annotations__", {})
718+
field_names = list(getattr(struct_type, "_fields", ()))
719+
values: list[Any] = []
720+
for name in field_names:
721+
f_type = annotations.get(name, Any)
722+
# Assume v is a sequence aligned with fields
723+
if isinstance(v, (list, tuple)):
724+
idx = field_names.index(name)
725+
values.append(load_engine_object(f_type, v[idx]))
726+
elif isinstance(v, Mapping):
727+
values.append(load_engine_object(f_type, v.get(name)))
728+
else:
729+
values.append(v)
730+
return struct_type(*values)
731+
return v
732+
733+
# Union with discriminator support via "kind"
734+
if isinstance(variant, AnalyzedUnionType):
735+
if isinstance(v, Mapping) and "kind" in v:
736+
discriminator = v["kind"]
737+
for typ in variant.variant_types:
738+
t_info = analyze_type_info(typ)
739+
if isinstance(t_info.variant, AnalyzedStructType):
740+
t_struct = t_info.variant.struct_type
741+
candidate_kind = getattr(t_struct, "kind", None)
742+
if candidate_kind == discriminator:
743+
# Remove discriminator for constructor
744+
v_wo_kind = dict(v)
745+
v_wo_kind.pop("kind", None)
746+
return load_engine_object(t_struct, v_wo_kind)
747+
# Fallback: try each variant until one succeeds
748+
for typ in variant.variant_types:
749+
try:
750+
return load_engine_object(typ, v)
751+
except (TypeError, ValueError):
752+
continue
753+
return v
754+
755+
# Basic types and everything else: handle numpy scalars and passthrough
756+
if isinstance(v, np.ndarray) and type_info.base_type is list:
757+
return v.tolist()
758+
if isinstance(v, (list, tuple)) and type_info.base_type not in (list, tuple):
759+
# If a non-sequence basic type expected, attempt direct cast
760+
try:
761+
return type_info.core_type(v)
762+
except (TypeError, ValueError):
763+
return v
764+
return v

python/cocoindex/op.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .subprocess_exec import executor_stub
2020
from .convert import (
2121
dump_engine_object,
22+
load_engine_object,
2223
make_engine_value_encoder,
2324
make_engine_value_decoder,
2425
make_engine_key_decoder,
@@ -90,15 +91,6 @@ class Executor(Protocol):
9091
op_category: OpCategory
9192

9293

93-
def _load_spec_from_engine(
94-
spec_loader: Callable[..., Any], spec: dict[str, Any]
95-
) -> Any:
96-
"""
97-
Load a spec from the engine.
98-
"""
99-
return spec_loader(**spec)
100-
101-
10294
def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
10395
method = getattr(cls, name, None)
10496
if method is None:
@@ -109,17 +101,17 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
109101

110102

111103
class _EngineFunctionExecutorFactory:
112-
_spec_loader: Callable[..., Any]
104+
_spec_loader: Callable[[Any], Any]
113105
_executor_cls: type
114106

115107
def __init__(self, spec_loader: Callable[..., Any], executor_cls: type):
116108
self._spec_loader = spec_loader
117109
self._executor_cls = executor_cls
118110

119111
def __call__(
120-
self, spec: dict[str, Any], *args: Any, **kwargs: Any
112+
self, raw_spec: dict[str, Any], *args: Any, **kwargs: Any
121113
) -> tuple[dict[str, Any], Executor]:
122-
spec = _load_spec_from_engine(self._spec_loader, spec)
114+
spec = self._spec_loader(raw_spec)
123115
executor = self._executor_cls(spec)
124116
result_type = executor.analyze_schema(*args, **kwargs)
125117
return (result_type, executor)
@@ -378,7 +370,7 @@ def _inner(cls: type[Executor]) -> type:
378370
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
379371
expected_return=sig.return_annotation,
380372
executor_factory=cls,
381-
spec_loader=spec_cls,
373+
spec_loader=lambda v: load_engine_object(spec_cls, v),
382374
op_kind=spec_cls.__name__,
383375
op_args=op_args,
384376
)
@@ -414,7 +406,7 @@ def _inner(fn: Callable[..., Any]) -> Callable[..., Any]:
414406
expected_args=list(sig.parameters.items()),
415407
expected_return=sig.return_annotation,
416408
executor_factory=_SimpleFunctionExecutor,
417-
spec_loader=lambda: fn,
409+
spec_loader=lambda _: fn,
418410
op_kind=op_kind,
419411
op_args=op_args,
420412
)
@@ -469,9 +461,9 @@ class _TargetConnector:
469461
The connector class passed to the engine.
470462
"""
471463

472-
_spec_cls: type
473-
_state_cls: type
474-
_connector_cls: type
464+
_spec_cls: type[Any]
465+
_state_cls: type[Any]
466+
_connector_cls: type[Any]
475467

476468
_get_persistent_key_fn: Callable[[_TargetConnectorContext, str], Any]
477469
_apply_setup_change_async_fn: Callable[
@@ -480,7 +472,9 @@ class _TargetConnector:
480472
_mutate_async_fn: Callable[..., Awaitable[None]]
481473
_mutatation_type: AnalyzedDictType | None
482474

483-
def __init__(self, spec_cls: type, state_cls: type, connector_cls: type):
475+
def __init__(
476+
self, spec_cls: type[Any], state_cls: type[Any], connector_cls: type[Any]
477+
):
484478
self._spec_cls = spec_cls
485479
self._state_cls = state_cls
486480
self._connector_cls = connector_cls
@@ -546,7 +540,7 @@ def _analyze_mutate_mutation_type(
546540
def create_export_context(
547541
self,
548542
name: str,
549-
spec: dict[str, Any],
543+
raw_spec: dict[str, Any],
550544
raw_key_fields_schema: list[Any],
551545
raw_value_fields_schema: list[Any],
552546
) -> _TargetConnectorContext:
@@ -568,10 +562,10 @@ def create_export_context(
568562
["<value>"], value_fields_schema, analyze_type_info(value_annotation)
569563
)
570564

571-
loaded_spec = _load_spec_from_engine(self._spec_cls, spec)
565+
spec = load_engine_object(self._spec_cls, raw_spec)
572566
return _TargetConnectorContext(
573567
target_name=name,
574-
spec=loaded_spec,
568+
spec=spec,
575569
prepared_spec=None,
576570
key_fields_schema=key_fields_schema,
577571
key_decoder=key_decoder,
@@ -638,13 +632,11 @@ async def apply_setup_changes_async(
638632
) -> None:
639633
for key, previous, current in changes:
640634
prev_specs = [
641-
_load_spec_from_engine(self._state_cls, spec)
642-
if spec is not None
643-
else None
635+
load_engine_object(self._state_cls, spec) if spec is not None else None
644636
for spec in previous
645637
]
646638
curr_spec = (
647-
_load_spec_from_engine(self._state_cls, current)
639+
load_engine_object(self._state_cls, current)
648640
if current is not None
649641
else None
650642
)
@@ -678,7 +670,7 @@ async def mutate_async(
678670

679671

680672
def target_connector(
681-
spec_cls: type, state_cls: type | None = None
673+
spec_cls: type[Any], state_cls: type[Any] | None = None
682674
) -> Callable[[type], type]:
683675
"""
684676
Decorate a class to provide a target connector for an op.

0 commit comments

Comments
 (0)