Skip to content

Commit 6485625

Browse files
committed
feat(convert): support explicit type hints in engine value encoding
1 parent b3d39f7 commit 6485625

File tree

3 files changed

+116
-38
lines changed

3 files changed

+116
-38
lines changed

python/cocoindex/convert.py

Lines changed: 103 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,69 +28,147 @@
2828
is_struct_type,
2929
)
3030

31+
# Ordered (source kind, destination kind) pairs that are accepted as
# convertible even though the two kind names differ. Consulted by
# `_is_type_kind_convertible_to`; membership is directional, so only the
# (src, dst) orientation listed here matches.
_CONVERTIBLE_KINDS = {
    ("Float32", "Float64"),
    ("LocalDateTime", "OffsetDateTime"),
}
3135

32-
def encode_engine_value(
33-
value: Any, in_struct: bool = False, type_hint: Type[Any] | str | None = None
36+
37+
def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool:
    """Return whether a value of kind `src_type_kind` is accepted for `dst_type_kind`.

    Identical kinds always match; otherwise the ordered pair must be listed
    in `_CONVERTIBLE_KINDS`.
    """
    if src_type_kind == dst_type_kind:
        return True
    return (src_type_kind, dst_type_kind) in _CONVERTIBLE_KINDS
42+
43+
44+
def _encode_engine_value_core(
    value: Any,
    in_struct: bool = False,
    type_hint: Type[Any] | str | None = None,
    type_variant: AnalyzedTypeInfo | None = None,
) -> Any:
    """Convert a Python value to its engine representation.

    Args:
        value: The Python value to encode.
        in_struct: Whether the value sits inside a struct; only affects how an
            empty dict is encoded.
        type_hint: Raw type annotation; only consulted for dicts, to detect a
            "Json" kind when `type_variant` is not a basic type.
        type_variant: Already-analyzed type information; used to pick
            element/value encoders for lists and dicts and to detect "Json".

    Returns:
        The encoded engine value. Values that match none of the handled
        shapes are returned unchanged.
    """
    # Struct-like values become a positional list of their encoded fields.
    if dataclasses.is_dataclass(value):
        return [
            encode_engine_value(
                getattr(value, field.name), in_struct=True, type_hint=field.type
            )
            for field in dataclasses.fields(value)
        ]

    if is_namedtuple_type(type(value)):
        field_hints = type(value).__annotations__
        return [
            encode_engine_value(
                getattr(value, field_name),
                in_struct=True,
                type_hint=field_hints.get(field_name),
            )
            for field_name in value._fields
        ]

    # NumPy scalars are unwrapped to Python scalars; arrays pass through as-is.
    if isinstance(value, np.number):
        return value.item()
    if isinstance(value, np.ndarray):
        return value

    if isinstance(value, (list, tuple)):
        variant = type_variant.variant if type_variant else None
        if isinstance(variant, AnalyzedListType) and variant.elem_type:
            # Build the element encoder once and reuse it for every item.
            encode_elem = make_engine_value_encoder(variant.elem_type)
            return [encode_elem(item) for item in value]
        return [encode_engine_value(item, in_struct) for item in value]

    if not isinstance(value, dict):
        return value

    # --- dict handling ---
    variant = type_variant.variant if type_variant else None

    # Decide whether this dict is annotated as a JSON value.
    if isinstance(variant, AnalyzedBasicType):
        is_json = variant.kind == "Json"
    elif type_hint:
        hinted = analyze_type_info(type_hint)
        is_json = (
            isinstance(hinted.variant, AnalyzedBasicType)
            and hinted.variant.kind == "Json"
        )
    else:
        is_json = False

    # Empty dict: the encoding depends on the JSON flag and struct context.
    if not value:
        if is_json:
            return value if in_struct else {}
        return [] if in_struct else value

    # KTable: values are structs, so each entry becomes [key, *struct_fields].
    first_entry = next(iter(value.values()))
    if is_struct_type(type(first_entry)):
        rows = []
        for key, row in value.items():
            encoded_key = encode_engine_value(key, in_struct)
            rows.append([encoded_key] + encode_engine_value(row, in_struct))
        return rows

    # Regular dict with a known value type: encode each value, keep the keys.
    if isinstance(variant, AnalyzedDictType) and variant.value_type:
        encode_val = make_engine_value_encoder(variant.value_type)
        return {key: encode_val(item) for key, item in value.items()}

    return value
81127

82128

83-
_CONVERTIBLE_KINDS = {
84-
("Float32", "Float64"),
85-
("LocalDateTime", "OffsetDateTime"),
86-
}
129+
def make_engine_value_encoder(type_annotation: Any) -> Callable[[Any], Any]:
    """
    Build an encoder from a Python value to an engine value.

    Args:
        type_annotation: The type annotation of the Python value.

    Returns:
        A single-argument callable that encodes a Python value to an engine
        value using the analyzed annotation, always with `in_struct=False`.

    Raises:
        ValueError: If the annotation analyzes to an unknown type.
    """
    type_info = analyze_type_info(type_annotation)
    if isinstance(type_info.variant, AnalyzedUnknownType):
        raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")

    # The analyzed type info is captured once and shared by every call.
    return lambda value: _encode_engine_value_core(
        value, in_struct=False, type_variant=type_info
    )
150+
151+
152+
def encode_engine_value(
    value: Any, in_struct: bool = False, type_hint: Type[Any] | str | None = None
) -> Any:
    """
    Encode a Python value to an engine value.

    Args:
        value: The Python value to encode.
        in_struct: Whether this value is being encoded within a struct context.
            Honored whether or not a type hint is supplied.
        type_hint: Type annotation for the value. When provided, it is analyzed
            and enables type-aware encoding. For top-level calls, this should
            always be provided.

    Returns:
        The encoded engine value.

    Raises:
        ValueError: If `type_hint` is provided but analyzes to an unknown type.
    """
    if type_hint is None:
        return _encode_engine_value_core(value, in_struct=in_struct)

    type_info = analyze_type_info(type_hint)
    if isinstance(type_info.variant, AnalyzedUnknownType):
        raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")

    # Call the core directly instead of going through
    # `make_engine_value_encoder`: the encoder it returns always uses
    # in_struct=False, which would drop the caller's struct context (e.g. an
    # empty non-JSON dict field of a struct must encode as [] rather than {}).
    return _encode_engine_value_core(
        value, in_struct=in_struct, type_variant=type_info
    )
94172

95173

96174
def make_engine_value_decoder(

python/cocoindex/flow.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,20 @@
99
import functools
1010
import inspect
1111
import re
12-
13-
from .validation import (
14-
validate_flow_name,
15-
NamingError,
16-
validate_full_flow_name,
17-
validate_target_name,
18-
)
19-
2012
from dataclasses import dataclass
2113
from enum import Enum
2214
from threading import Lock
2315
from typing import (
2416
Any,
2517
Callable,
2618
Generic,
19+
Iterable,
2720
NamedTuple,
2821
Sequence,
2922
TypeVar,
3023
cast,
3124
get_args,
3225
get_origin,
33-
Iterable,
3426
)
3527

3628
from rich.text import Text
@@ -45,6 +37,11 @@
4537
from .runtime import execution_context
4638
from .setup import SetupChangeBundle
4739
from .typing import encode_enriched_type
40+
from .validation import (
41+
validate_flow_name,
42+
validate_full_flow_name,
43+
validate_target_name,
44+
)
4845

4946

5047
class _NameBuilder:
@@ -1083,10 +1080,13 @@ async def eval_async(self, *args: Any, **kwargs: Any) -> T:
10831080
flow_info = await self._flow_info_async()
10841081
params = []
10851082
for i, arg in enumerate(self._param_names):
1083+
param_type = (
1084+
self._flow_arg_types[i] if i < len(self._flow_arg_types) else None
1085+
)
10861086
if i < len(args):
1087-
params.append(encode_engine_value(args[i]))
1087+
params.append(encode_engine_value(args[i], type_hint=param_type))
10881088
elif arg in kwargs:
1089-
params.append(encode_engine_value(kwargs[arg]))
1089+
params.append(encode_engine_value(kwargs[arg], type_hint=param_type))
10901090
else:
10911091
raise ValueError(f"Parameter {arg} is not provided")
10921092
engine_result = await flow_info.engine_flow.evaluate_async(params)

python/cocoindex/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import dataclasses
77
import inspect
88
from enum import Enum
9-
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform, Annotated
9+
from typing import Annotated, Any, Awaitable, Callable, Protocol, dataclass_transform
1010

1111
from . import _engine # type: ignore
1212
from .convert import encode_engine_value, make_engine_value_decoder
@@ -277,7 +277,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
277277
output = await self._acall(*decoded_args, **decoded_kwargs)
278278
else:
279279
output = await self._acall(*decoded_args, **decoded_kwargs)
280-
return encode_engine_value(output)
280+
return encode_engine_value(output, type_hint=expected_return)
281281

282282
_WrappedClass.__name__ = executor_cls.__name__
283283
_WrappedClass.__doc__ = executor_cls.__doc__

0 commit comments

Comments
 (0)