Skip to content

Commit 68a6e46

Browse files
committed
feat: split convert.py into two modules, engine_object and engine_value, along with relevant tests
1 parent ac786d7 commit 68a6e46

File tree

8 files changed

+385
-336
lines changed

8 files changed

+385
-336
lines changed

python/cocoindex/auth_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Generic, TypeVar
77

88
from . import _engine # type: ignore
9-
from .convert import dump_engine_object, load_engine_object
9+
from .engine_object import dump_engine_object, load_engine_object
1010

1111
T = TypeVar("T")
1212

python/cocoindex/engine_object.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
"""
2+
Utilities to dump/load objects (for configs, specs).
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import datetime
8+
import dataclasses
9+
from enum import Enum
10+
from typing import Any, Mapping, TypeVar, overload, get_origin
11+
12+
import numpy as np
13+
14+
from .typing import (
15+
AnalyzedAnyType,
16+
AnalyzedBasicType,
17+
AnalyzedDictType,
18+
AnalyzedListType,
19+
AnalyzedStructType,
20+
AnalyzedTypeInfo,
21+
AnalyzedUnionType,
22+
EnrichedValueType,
23+
FieldSchema,
24+
analyze_type_info,
25+
encode_enriched_type,
26+
is_namedtuple_type,
27+
is_pydantic_model,
28+
extract_ndarray_elem_dtype,
29+
)
30+
31+
32+
T = TypeVar("T")
33+
34+
try:
35+
import pydantic, pydantic_core
36+
except ImportError:
37+
pass
38+
39+
40+
def _get_auto_default_for_type(
41+
type_info: AnalyzedTypeInfo,
42+
) -> tuple[Any, bool]:
43+
"""
44+
Get an auto-default value for a type annotation if it's safe to do so.
45+
46+
Returns:
47+
A tuple of (default_value, is_supported) where:
48+
- default_value: The default value if auto-defaulting is supported
49+
- is_supported: True if auto-defaulting is supported for this type
50+
"""
51+
# Case 1: Nullable types (Optional[T] or T | None)
52+
if type_info.nullable:
53+
return None, True
54+
55+
# Case 2: Table types (KTable or LTable) - check if it's a list or dict type
56+
if isinstance(type_info.variant, AnalyzedListType):
57+
return [], True
58+
elif isinstance(type_info.variant, AnalyzedDictType):
59+
return {}, True
60+
61+
return None, False
62+
63+
64+
def dump_engine_object(v: Any) -> Any:
    """Recursively dump an object for engine. Engine side uses `Pythonized` to catch.

    Args:
        v: Any Python object (spec/config value) to convert into an
           engine-facing representation built from dicts/lists/primitives.

    Returns:
        A plain-Python structure (dict/list/primitive) mirroring `v`.
    """
    if v is None:
        return None
    elif isinstance(v, EnrichedValueType):
        return v.encode()
    elif isinstance(v, FieldSchema):
        return v.encode()
    elif isinstance(v, type) or get_origin(v) is not None:
        # A type annotation (class or parameterized generic) — encode its schema.
        return encode_enriched_type(v)
    elif isinstance(v, Enum):
        return v.value
    elif isinstance(v, datetime.timedelta):
        # Use exact integer arithmetic on (days, seconds, microseconds):
        # total_seconds() returns a float and silently loses microsecond
        # precision for large deltas (documented timedelta limitation).
        total_micros = (v.days * 86_400 + v.seconds) * 1_000_000 + v.microseconds
        if total_micros >= 0:
            secs = total_micros // 1_000_000
        else:
            # Truncate toward zero to match the previous int(float) semantics,
            # so negative deltas keep the same {secs, nanos} sign convention.
            secs = -((-total_micros) // 1_000_000)
        nanos = (total_micros - secs * 1_000_000) * 1_000
        return {"secs": secs, "nanos": nanos}
    elif is_namedtuple_type(type(v)):
        # Handle NamedTuple objects specifically to use dict format
        field_names = list(getattr(type(v), "_fields", ()))
        result = {}
        for name in field_names:
            val = getattr(v, name)
            result[name] = dump_engine_object(val)  # Include all values, including None
        if hasattr(v, "kind") and "kind" not in result:
            # Carry a class-level "kind" discriminator into the payload.
            result["kind"] = v.kind
        return result
    elif hasattr(v, "__dict__"):  # for dataclass-like objects
        s = {}
        for k, val in v.__dict__.items():
            if val is None:
                # Skip None values
                continue
            s[k] = dump_engine_object(val)
        if hasattr(v, "kind") and "kind" not in s:
            s["kind"] = v.kind
        return s
    elif isinstance(v, (list, tuple)):
        return [dump_engine_object(item) for item in v]
    elif isinstance(v, np.ndarray):
        return v.tolist()
    elif isinstance(v, dict):
        # Keys are passed through unchanged; only values are recursively dumped.
        return {key: dump_engine_object(val) for key, val in v.items()}
    return v
108+
109+
110+
@overload
def load_engine_object(expected_type: type[T], v: Any) -> T: ...
@overload
def load_engine_object(expected_type: Any, v: Any) -> Any: ...
def load_engine_object(expected_type: Any, v: Any) -> Any:
    """Recursively load an object that was produced by dump_engine_object().

    Args:
        expected_type: The Python type annotation to reconstruct to.
        v: The engine-facing Pythonized object (e.g., dict/list/primitive) to convert.

    Returns:
        A Python object matching the expected_type where possible.
    """
    # Fast path
    if v is None:
        return None

    type_info = analyze_type_info(expected_type)
    variant = type_info.variant

    # Schema wrapper types decode via their own classmethods.
    if type_info.core_type is EnrichedValueType:
        return EnrichedValueType.decode(v)
    if type_info.core_type is FieldSchema:
        return FieldSchema.decode(v)

    # Any or unknown → return as-is
    if isinstance(variant, AnalyzedAnyType) or type_info.base_type is Any:
        return v

    # Enum handling
    if isinstance(expected_type, type) and issubclass(expected_type, Enum):
        return expected_type(v)

    # TimeDelta special form {secs, nanos}
    if isinstance(variant, AnalyzedBasicType) and variant.kind == "TimeDelta":
        if isinstance(v, Mapping) and "secs" in v and "nanos" in v:
            secs = int(v["secs"])  # type: ignore[index]
            nanos = int(v["nanos"])  # type: ignore[index]
            # NOTE(review): nanos → microseconds drops sub-microsecond
            # precision; timedelta cannot represent nanoseconds.
            return datetime.timedelta(seconds=secs, microseconds=nanos / 1_000)
        # Not in the special form — pass through unchanged.
        return v

    # List, NDArray (Vector-ish), or general sequences
    if isinstance(variant, AnalyzedListType):
        elem_type = variant.elem_type if variant.elem_type else Any
        if type_info.base_type is np.ndarray:
            # Reconstruct NDArray with appropriate dtype if available
            try:
                dtype = extract_ndarray_elem_dtype(type_info.core_type)
            except (TypeError, ValueError, AttributeError):
                dtype = None
            return np.array(v, dtype=dtype)
        # Regular Python list
        return [load_engine_object(elem_type, item) for item in v]

    # Dict / Mapping
    if isinstance(variant, AnalyzedDictType):
        key_t = variant.key_type
        val_t = variant.value_type
        # Both keys and values are reconstructed recursively.
        return {
            load_engine_object(key_t, k): load_engine_object(val_t, val)
            for k, val in v.items()
        }

    # Structs (dataclass, NamedTuple, or Pydantic)
    if isinstance(variant, AnalyzedStructType):
        struct_type = variant.struct_type
        init_kwargs: dict[str, Any] = {}
        # Fields absent from `v` that also have no declared default; they get
        # auto-defaults (via _get_auto_default_for_type) where supported.
        missing_fields: list[tuple[str, Any]] = []
        if dataclasses.is_dataclass(struct_type):
            if not isinstance(v, Mapping):
                raise ValueError(f"Expected dict for dataclass, got {type(v)}")

            for dc_field in dataclasses.fields(struct_type):
                if dc_field.name in v:
                    # NOTE(review): dc_field.type may be a string when the
                    # dataclass module uses `from __future__ import annotations`;
                    # assumes analyze_type_info handles that — TODO confirm.
                    init_kwargs[dc_field.name] = load_engine_object(
                        dc_field.type, v[dc_field.name]
                    )
                else:
                    if (
                        dc_field.default is dataclasses.MISSING
                        and dc_field.default_factory is dataclasses.MISSING
                    ):
                        missing_fields.append((dc_field.name, dc_field.type))

        elif is_namedtuple_type(struct_type):
            if not isinstance(v, Mapping):
                raise ValueError(f"Expected dict for NamedTuple, got {type(v)}")
            # Dict format (from dump/load functions)
            annotations = getattr(struct_type, "__annotations__", {})
            field_names = list(getattr(struct_type, "_fields", ()))
            field_defaults = getattr(struct_type, "_field_defaults", {})

            for name in field_names:
                f_type = annotations.get(name, Any)
                if name in v:
                    init_kwargs[name] = load_engine_object(f_type, v[name])
                elif name not in field_defaults:
                    missing_fields.append((name, f_type))

        elif is_pydantic_model(struct_type):
            if not isinstance(v, Mapping):
                raise ValueError(f"Expected dict for Pydantic model, got {type(v)}")

            model_fields: dict[str, pydantic.fields.FieldInfo]
            if hasattr(struct_type, "model_fields"):
                model_fields = struct_type.model_fields  # type: ignore[attr-defined]
            else:
                model_fields = {}

            for name, pyd_field in model_fields.items():
                if name in v:
                    init_kwargs[name] = load_engine_object(
                        pyd_field.annotation, v[name]
                    )
                elif (
                    # Field has neither a default nor a default_factory.
                    getattr(pyd_field, "default", pydantic_core.PydanticUndefined)
                    is pydantic_core.PydanticUndefined
                    and getattr(pyd_field, "default_factory") is None
                ):
                    missing_fields.append((name, pyd_field.annotation))
        else:
            assert False, "Unsupported struct type"

        for name, f_type in missing_fields:
            # NOTE(review): rebinds the outer `type_info`; harmless because the
            # struct branch returns below, but worth renaming in a follow-up.
            type_info = analyze_type_info(f_type)
            auto_default, is_supported = _get_auto_default_for_type(type_info)
            if is_supported:
                init_kwargs[name] = auto_default
        # Unsupported missing fields are omitted, letting the constructor
        # raise its own TypeError if they were actually required.
        return struct_type(**init_kwargs)

    # Union with discriminator support via "kind"
    if isinstance(variant, AnalyzedUnionType):
        if isinstance(v, Mapping) and "kind" in v:
            discriminator = v["kind"]
            for typ in variant.variant_types:
                t_info = analyze_type_info(typ)
                if isinstance(t_info.variant, AnalyzedStructType):
                    t_struct = t_info.variant.struct_type
                    # Match against a class-level "kind" attribute, mirroring
                    # how dump_engine_object emits the discriminator.
                    candidate_kind = getattr(t_struct, "kind", None)
                    if candidate_kind == discriminator:
                        # Remove discriminator for constructor
                        v_wo_kind = dict(v)
                        v_wo_kind.pop("kind", None)
                        return load_engine_object(t_struct, v_wo_kind)
        # Fallback: try each variant until one succeeds
        for typ in variant.variant_types:
            try:
                return load_engine_object(typ, v)
            except (TypeError, ValueError):
                continue
        return v

    # Basic types and everything else: handle numpy scalars and passthrough
    if isinstance(v, np.ndarray) and type_info.base_type is list:
        return v.tolist()
    if isinstance(v, (list, tuple)) and type_info.base_type not in (list, tuple):
        # If a non-sequence basic type expected, attempt direct cast
        try:
            return type_info.core_type(v)
        except (TypeError, ValueError):
            return v
    return v

0 commit comments

Comments
 (0)