Skip to content

Commit 579aa96

Browse files
authored
feat(convert): create the encoder only once (#887)
1 parent fb92792 commit 579aa96

File tree

4 files changed

+81
-105
lines changed

4 files changed

+81
-105
lines changed

python/cocoindex/convert.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,21 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
6767
ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)
6868

6969

70-
def _make_encoder_closure(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
70+
def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
7171
"""
7272
Create an encoder closure for a specific type.
7373
"""
7474
variant = type_info.variant
7575

76+
if isinstance(variant, AnalyzedUnknownType):
77+
raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")
78+
7679
if isinstance(variant, AnalyzedListType):
7780
elem_type_info = (
7881
analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO
7982
)
8083
if isinstance(elem_type_info.variant, AnalyzedStructType):
81-
elem_encoder = _make_encoder_closure(elem_type_info)
84+
elem_encoder = make_engine_value_encoder(elem_type_info)
8285

8386
def encode_struct_list(value: Any) -> Any:
8487
return None if value is None else [elem_encoder(v) for v in value]
@@ -104,11 +107,13 @@ def encode_struct_dict(value: Any) -> Any:
104107
# Handle KTable case
105108
if value and is_struct_type(val_type):
106109
key_encoder = (
107-
_make_encoder_closure(analyze_type_info(key_type))
110+
make_engine_value_encoder(analyze_type_info(key_type))
108111
if is_struct_type(key_type)
109-
else _make_encoder_closure(ANY_TYPE_INFO)
112+
else make_engine_value_encoder(ANY_TYPE_INFO)
113+
)
114+
value_encoder = make_engine_value_encoder(
115+
analyze_type_info(val_type)
110116
)
111-
value_encoder = _make_encoder_closure(analyze_type_info(val_type))
112117
return [
113118
[key_encoder(k)] + value_encoder(v) for k, v in value.items()
114119
]
@@ -122,7 +127,7 @@ def encode_struct_dict(value: Any) -> Any:
122127
if dataclasses.is_dataclass(struct_type):
123128
fields = dataclasses.fields(struct_type)
124129
field_encoders = [
125-
_make_encoder_closure(analyze_type_info(f.type)) for f in fields
130+
make_engine_value_encoder(analyze_type_info(f.type)) for f in fields
126131
]
127132
field_names = [f.name for f in fields]
128133

@@ -140,7 +145,7 @@ def encode_dataclass(value: Any) -> Any:
140145
annotations = struct_type.__annotations__
141146
field_names = list(getattr(struct_type, "_fields", ()))
142147
field_encoders = [
143-
_make_encoder_closure(
148+
make_engine_value_encoder(
144149
analyze_type_info(annotations[name])
145150
if name in annotations
146151
else ANY_TYPE_INFO
@@ -170,38 +175,6 @@ def encode_basic_value(value: Any) -> Any:
170175
return encode_basic_value
171176

172177

173-
def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]:
174-
"""
175-
Create an encoder closure for converting Python values to engine values.
176-
177-
Args:
178-
type_hint: Type annotation for the values to encode
179-
180-
Returns:
181-
A closure that encodes Python values to engine values
182-
"""
183-
type_info = analyze_type_info(type_hint)
184-
if isinstance(type_info.variant, AnalyzedUnknownType):
185-
raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")
186-
187-
return _make_encoder_closure(type_info)
188-
189-
190-
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
191-
"""
192-
Encode a Python value to an engine value.
193-
194-
Args:
195-
value: The Python value to encode
196-
type_hint: Type annotation for the value. This should always be provided.
197-
198-
Returns:
199-
The encoded engine value
200-
"""
201-
encoder = make_engine_value_encoder(type_hint)
202-
return encoder(value)
203-
204-
205178
def make_engine_value_decoder(
206179
field_path: list[str],
207180
src_type: dict[str, Any],

python/cocoindex/flow.py

Lines changed: 47 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
from . import index
3333
from . import op
3434
from . import setting
35-
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
35+
from .convert import (
36+
dump_engine_object,
37+
make_engine_value_decoder,
38+
make_engine_value_encoder,
39+
)
3640
from .op import FunctionSpec
3741
from .runtime import execution_context
3842
from .setup import SetupChangeBundle
@@ -974,33 +978,60 @@ class TransformFlowInfo(NamedTuple):
974978
result_decoder: Callable[[Any], T]
975979

976980

981+
class FlowArgInfo(NamedTuple):
982+
name: str
983+
type_hint: Any
984+
encoder: Callable[[Any], Any]
985+
986+
977987
class TransformFlow(Generic[T]):
978988
"""
979989
A transient transformation flow that transforms in-memory data.
980990
"""
981991

982992
_flow_fn: Callable[..., DataSlice[T]]
983993
_flow_name: str
984-
_flow_arg_types: list[Any]
985-
_param_names: list[str]
994+
_args_info: list[FlowArgInfo]
986995

987996
_lazy_lock: asyncio.Lock
988997
_lazy_flow_info: TransformFlowInfo | None = None
989998

990999
def __init__(
9911000
self,
9921001
flow_fn: Callable[..., DataSlice[T]],
993-
flow_arg_types: Sequence[Any],
9941002
/,
9951003
name: str | None = None,
9961004
):
9971005
self._flow_fn = flow_fn
9981006
self._flow_name = _transform_flow_name_builder.build_name(
9991007
name, prefix="_transform_flow_"
10001008
)
1001-
self._flow_arg_types = list(flow_arg_types)
10021009
self._lazy_lock = asyncio.Lock()
10031010

1011+
sig = inspect.signature(flow_fn)
1012+
args_info = []
1013+
for param_name, param in sig.parameters.items():
1014+
if param.kind not in (
1015+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
1016+
inspect.Parameter.KEYWORD_ONLY,
1017+
):
1018+
raise ValueError(
1019+
f"Parameter `{param_name}` is not a parameter can be passed by name"
1020+
)
1021+
value_type_annotation: type | None = _get_data_slice_annotation_type(
1022+
param.annotation
1023+
)
1024+
if value_type_annotation is None:
1025+
raise ValueError(
1026+
f"Parameter `{param_name}` for {flow_fn} has no value type annotation. "
1027+
"Please use `cocoindex.DataSlice[T]` where T is the type of the value."
1028+
)
1029+
encoder = make_engine_value_encoder(
1030+
analyze_type_info(value_type_annotation)
1031+
)
1032+
args_info.append(FlowArgInfo(param_name, value_type_annotation, encoder))
1033+
self._args_info = args_info
1034+
10041035
def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[T]:
10051036
return self._flow_fn(*args, **kwargs)
10061037

@@ -1020,31 +1051,15 @@ async def _flow_info_async(self) -> TransformFlowInfo:
10201051

10211052
async def _build_flow_info_async(self) -> TransformFlowInfo:
10221053
flow_builder_state = _FlowBuilderState(self._flow_name)
1023-
sig = inspect.signature(self._flow_fn)
1024-
if len(sig.parameters) != len(self._flow_arg_types):
1025-
raise ValueError(
1026-
f"Number of parameters in the flow function ({len(sig.parameters)}) "
1027-
f"does not match the number of argument types ({len(self._flow_arg_types)})"
1028-
)
1029-
10301054
kwargs: dict[str, DataSlice[T]] = {}
1031-
for (param_name, param), param_type in zip(
1032-
sig.parameters.items(), self._flow_arg_types
1033-
):
1034-
if param.kind not in (
1035-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
1036-
inspect.Parameter.KEYWORD_ONLY,
1037-
):
1038-
raise ValueError(
1039-
f"Parameter `{param_name}` is not a parameter can be passed by name"
1040-
)
1041-
encoded_type = encode_enriched_type(param_type)
1055+
for arg_info in self._args_info:
1056+
encoded_type = encode_enriched_type(arg_info.type_hint)
10421057
if encoded_type is None:
1043-
raise ValueError(f"Parameter `{param_name}` has no type annotation")
1058+
raise ValueError(f"Parameter `{arg_info.name}` has no type annotation")
10441059
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
1045-
param_name, encoded_type
1060+
arg_info.name, encoded_type
10461061
)
1047-
kwargs[param_name] = DataSlice(
1062+
kwargs[arg_info.name] = DataSlice(
10481063
_DataSliceState(flow_builder_state, engine_ds)
10491064
)
10501065

@@ -1057,13 +1072,12 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
10571072
execution_context.event_loop
10581073
)
10591074
)
1060-
self._param_names = list(sig.parameters.keys())
10611075

10621076
engine_return_type = (
10631077
_data_slice_state(output).engine_data_slice.data_type().schema()
10641078
)
10651079
python_return_type: type[T] | None = _get_data_slice_annotation_type(
1066-
sig.return_annotation
1080+
inspect.signature(self._flow_fn).return_annotation
10671081
)
10681082
result_decoder = make_engine_value_decoder(
10691083
[], engine_return_type["type"], analyze_type_info(python_return_type)
@@ -1095,18 +1109,14 @@ async def eval_async(self, *args: Any, **kwargs: Any) -> T:
10951109
"""
10961110
flow_info = await self._flow_info_async()
10971111
params = []
1098-
for i, (arg, arg_type) in enumerate(
1099-
zip(self._param_names, self._flow_arg_types)
1100-
):
1101-
param_type = (
1102-
self._flow_arg_types[i] if i < len(self._flow_arg_types) else Any
1103-
)
1112+
for i, arg_info in enumerate(self._args_info):
11041113
if i < len(args):
1105-
params.append(encode_engine_value(args[i], type_hint=param_type))
1114+
arg = args[i]
11061115
elif arg in kwargs:
1107-
params.append(encode_engine_value(kwargs[arg], type_hint=param_type))
1116+
arg = kwargs[arg]
11081117
else:
11091118
raise ValueError(f"Parameter {arg} is not provided")
1119+
params.append(arg_info.encoder(arg))
11101120
engine_result = await flow_info.engine_flow.evaluate_async(params)
11111121
return flow_info.result_decoder(engine_result)
11121122

@@ -1117,27 +1127,7 @@ def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]
11171127
"""
11181128

11191129
def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]:
1120-
sig = inspect.signature(fn)
1121-
arg_types = []
1122-
for param_name, param in sig.parameters.items():
1123-
if param.kind not in (
1124-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
1125-
inspect.Parameter.KEYWORD_ONLY,
1126-
):
1127-
raise ValueError(
1128-
f"Parameter `{param_name}` is not a parameter can be passed by name"
1129-
)
1130-
value_type_annotation: type[T] | None = _get_data_slice_annotation_type(
1131-
param.annotation
1132-
)
1133-
if value_type_annotation is None:
1134-
raise ValueError(
1135-
f"Parameter `{param_name}` for {fn} has no value type annotation. "
1136-
"Please use `cocoindex.DataSlice[T]` where T is the type of the value."
1137-
)
1138-
arg_types.append(value_type_annotation)
1139-
1140-
_transform_flow = TransformFlow(fn, arg_types)
1130+
_transform_flow = TransformFlow(fn)
11411131
functools.update_wrapper(_transform_flow, fn)
11421132
return _transform_flow
11431133

python/cocoindex/op.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818

1919
from . import _engine # type: ignore
2020
from .convert import (
21-
encode_engine_value,
21+
make_engine_value_encoder,
2222
make_engine_value_decoder,
2323
make_engine_struct_decoder,
2424
)
2525
from .typing import (
2626
TypeAttr,
27-
encode_enriched_type,
27+
encode_enriched_type_info,
2828
resolve_forward_ref,
2929
analyze_type_info,
3030
AnalyzedAnyType,
@@ -185,6 +185,7 @@ class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc]
185185
_args_info: list[_ArgInfo]
186186
_kwargs_info: dict[str, _ArgInfo]
187187
_acall: Callable[..., Awaitable[Any]]
188+
_result_encoder: Callable[[Any], Any]
188189

189190
def __init__(self, spec: Any) -> None:
190191
super().__init__()
@@ -295,15 +296,19 @@ def process_arg(
295296

296297
base_analyze_method = getattr(self, "analyze", None)
297298
if base_analyze_method is not None:
298-
result = base_analyze_method(*args, **kwargs)
299+
result_type = base_analyze_method(*args, **kwargs)
299300
else:
300-
result = expected_return
301+
result_type = expected_return
301302
if len(attributes) > 0:
302-
result = Annotated[result, *attributes]
303+
result_type = Annotated[result_type, *attributes]
303304

304-
encoded_type = encode_enriched_type(result)
305+
analyzed_result_type_info = analyze_type_info(result_type)
306+
encoded_type = encode_enriched_type_info(analyzed_result_type_info)
305307
if potentially_missing_required_arg:
306308
encoded_type["nullable"] = True
309+
310+
self._result_encoder = make_engine_value_encoder(analyzed_result_type_info)
311+
307312
return encoded_type
308313

309314
async def prepare(self) -> None:
@@ -343,7 +348,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
343348
output = await self._acall(*decoded_args, **decoded_kwargs)
344349
else:
345350
output = await self._acall(*decoded_args, **decoded_kwargs)
346-
return encode_engine_value(output, type_hint=expected_return)
351+
return self._result_encoder(output)
347352

348353
_WrappedClass.__name__ = executor_cls.__name__
349354
_WrappedClass.__doc__ = executor_cls.__doc__

python/cocoindex/tests/test_convert.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
import uuid
44
from dataclasses import dataclass, make_dataclass
5-
from typing import Annotated, Any, Callable, Literal, NamedTuple
5+
from typing import Annotated, Any, Callable, Literal, NamedTuple, Type
66

77
import numpy as np
88
import pytest
@@ -11,7 +11,7 @@
1111
import cocoindex
1212
from cocoindex.convert import (
1313
dump_engine_object,
14-
encode_engine_value,
14+
make_engine_value_encoder,
1515
make_engine_value_decoder,
1616
)
1717
from cocoindex.typing import (
@@ -69,6 +69,14 @@ class CustomerNamedTuple(NamedTuple):
6969
tags: list[Tag] | None = None
7070

7171

72+
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
73+
"""
74+
Encode a Python value to an engine value.
75+
"""
76+
encoder = make_engine_value_encoder(analyze_type_info(type_hint))
77+
return encoder(value)
78+
79+
7280
def build_engine_value_decoder(
7381
engine_type_in_py: Any, python_type: Any | None = None
7482
) -> Callable[[Any], Any]:

0 commit comments

Comments
 (0)