diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 24cfd0d54..b34ddcf40 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -67,18 +67,21 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty) -def _make_encoder_closure(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]: +def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]: """ Create an encoder closure for a specific type. """ variant = type_info.variant + if isinstance(variant, AnalyzedUnknownType): + raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported") + if isinstance(variant, AnalyzedListType): elem_type_info = ( analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO ) if isinstance(elem_type_info.variant, AnalyzedStructType): - elem_encoder = _make_encoder_closure(elem_type_info) + elem_encoder = make_engine_value_encoder(elem_type_info) def encode_struct_list(value: Any) -> Any: return None if value is None else [elem_encoder(v) for v in value] @@ -104,11 +107,13 @@ def encode_struct_dict(value: Any) -> Any: # Handle KTable case if value and is_struct_type(val_type): key_encoder = ( - _make_encoder_closure(analyze_type_info(key_type)) + make_engine_value_encoder(analyze_type_info(key_type)) if is_struct_type(key_type) - else _make_encoder_closure(ANY_TYPE_INFO) + else make_engine_value_encoder(ANY_TYPE_INFO) + ) + value_encoder = make_engine_value_encoder( + analyze_type_info(val_type) ) - value_encoder = _make_encoder_closure(analyze_type_info(val_type)) return [ [key_encoder(k)] + value_encoder(v) for k, v in value.items() ] @@ -122,7 +127,7 @@ def encode_struct_dict(value: Any) -> Any: if dataclasses.is_dataclass(struct_type): fields = dataclasses.fields(struct_type) field_encoders = [ - _make_encoder_closure(analyze_type_info(f.type)) for f in fields + make_engine_value_encoder(analyze_type_info(f.type)) for f in fields ] field_names = [f.name for f in fields] @@ -140,7 +145,7 @@ def encode_dataclass(value: Any) -> Any: annotations = struct_type.__annotations__ field_names = list(getattr(struct_type, "_fields", ())) field_encoders = [ - _make_encoder_closure( + make_engine_value_encoder( analyze_type_info(annotations[name]) if name in annotations else ANY_TYPE_INFO @@ -170,38 +175,6 @@ def encode_basic_value(value: Any) -> Any: return encode_basic_value -def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]: - """ - Create an encoder closure for converting Python values to engine values. - - Args: - type_hint: Type annotation for the values to encode - - Returns: - A closure that encodes Python values to engine values - """ - type_info = analyze_type_info(type_hint) - if isinstance(type_info.variant, AnalyzedUnknownType): - raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported") - - return _make_encoder_closure(type_info) - - -def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any: - """ - Encode a Python value to an engine value. - - Args: - value: The Python value to encode - type_hint: Type annotation for the value. This should always be provided. - - Returns: - The encoded engine value - """ - encoder = make_engine_value_encoder(type_hint) - return encoder(value) - - def make_engine_value_decoder( field_path: list[str], src_type: dict[str, Any], diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 8a93b1afe..510052eb7 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -32,7 +32,11 @@ from . import index from . import op from . import setting -from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder +from .convert import ( + dump_engine_object, + make_engine_value_decoder, + make_engine_value_encoder, +) from .op import FunctionSpec from .runtime import execution_context from .setup import SetupChangeBundle @@ -974,6 +978,12 @@ class TransformFlowInfo(NamedTuple): result_decoder: Callable[[Any], T] +class FlowArgInfo(NamedTuple): + name: str + type_hint: Any + encoder: Callable[[Any], Any] + + class TransformFlow(Generic[T]): """ A transient transformation flow that transforms in-memory data. @@ -981,8 +991,7 @@ class TransformFlow(Generic[T]): _flow_fn: Callable[..., DataSlice[T]] _flow_name: str - _flow_arg_types: list[Any] - _param_names: list[str] + _args_info: list[FlowArgInfo] _lazy_lock: asyncio.Lock _lazy_flow_info: TransformFlowInfo | None = None @@ -990,7 +999,6 @@ class TransformFlow(Generic[T]): def __init__( self, flow_fn: Callable[..., DataSlice[T]], - flow_arg_types: Sequence[Any], /, name: str | None = None, ): @@ -998,9 +1006,32 @@ def __init__( self._flow_name = _transform_flow_name_builder.build_name( name, prefix="_transform_flow_" ) - self._flow_arg_types = list(flow_arg_types) self._lazy_lock = asyncio.Lock() + sig = inspect.signature(flow_fn) + args_info = [] + for param_name, param in sig.parameters.items(): + if param.kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + raise ValueError( + f"Parameter `{param_name}` is not a parameter can be passed by name" + ) + value_type_annotation: type | None = _get_data_slice_annotation_type( + param.annotation + ) + if value_type_annotation is None: + raise ValueError( + f"Parameter `{param_name}` for {flow_fn} has no value type annotation. " + "Please use `cocoindex.DataSlice[T]` where T is the type of the value." + ) + encoder = make_engine_value_encoder( + analyze_type_info(value_type_annotation) + ) + args_info.append(FlowArgInfo(param_name, value_type_annotation, encoder)) + self._args_info = args_info + def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[T]: return self._flow_fn(*args, **kwargs) @@ -1020,31 +1051,15 @@ async def _flow_info_async(self) -> TransformFlowInfo: async def _build_flow_info_async(self) -> TransformFlowInfo: flow_builder_state = _FlowBuilderState(self._flow_name) - sig = inspect.signature(self._flow_fn) - if len(sig.parameters) != len(self._flow_arg_types): - raise ValueError( - f"Number of parameters in the flow function ({len(sig.parameters)}) " - f"does not match the number of argument types ({len(self._flow_arg_types)})" - ) - kwargs: dict[str, DataSlice[T]] = {} - for (param_name, param), param_type in zip( - sig.parameters.items(), self._flow_arg_types - ): - if param.kind not in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ): - raise ValueError( - f"Parameter `{param_name}` is not a parameter can be passed by name" - ) - encoded_type = encode_enriched_type(param_type) + for arg_info in self._args_info: + encoded_type = encode_enriched_type(arg_info.type_hint) if encoded_type is None: - raise ValueError(f"Parameter `{param_name}` has no type annotation") + raise ValueError(f"Parameter `{arg_info.name}` has no type annotation") engine_ds = flow_builder_state.engine_flow_builder.add_direct_input( - param_name, encoded_type + arg_info.name, encoded_type ) - kwargs[param_name] = DataSlice( + kwargs[arg_info.name] = DataSlice( _DataSliceState(flow_builder_state, engine_ds) ) @@ -1057,13 +1072,12 @@ async def _build_flow_info_async(self) -> TransformFlowInfo: execution_context.event_loop ) ) - self._param_names = list(sig.parameters.keys()) engine_return_type = ( _data_slice_state(output).engine_data_slice.data_type().schema() ) python_return_type: type[T] | None = _get_data_slice_annotation_type( - sig.return_annotation + inspect.signature(self._flow_fn).return_annotation ) result_decoder = make_engine_value_decoder( [], engine_return_type["type"], analyze_type_info(python_return_type) @@ -1095,18 +1109,14 @@ async def eval_async(self, *args: Any, **kwargs: Any) -> T: """ flow_info = await self._flow_info_async() params = [] - for i, (arg, arg_type) in enumerate( - zip(self._param_names, self._flow_arg_types) - ): - param_type = ( - self._flow_arg_types[i] if i < len(self._flow_arg_types) else Any - ) + for i, arg_info in enumerate(self._args_info): if i < len(args): - params.append(encode_engine_value(args[i], type_hint=param_type)) + arg = args[i] elif arg in kwargs: - params.append(encode_engine_value(kwargs[arg], type_hint=param_type)) + arg = kwargs[arg] else: raise ValueError(f"Parameter {arg} is not provided") + params.append(arg_info.encoder(arg)) engine_result = await flow_info.engine_flow.evaluate_async(params) return flow_info.result_decoder(engine_result) @@ -1117,27 +1127,7 @@ def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T] """ def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]: - sig = inspect.signature(fn) - arg_types = [] - for param_name, param in sig.parameters.items(): - if param.kind not in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ): - raise ValueError( - f"Parameter `{param_name}` is not a parameter can be passed by name" - ) - value_type_annotation: type[T] | None = _get_data_slice_annotation_type( - param.annotation - ) - if value_type_annotation is None: - raise ValueError( - f"Parameter `{param_name}` for {fn} has no value type annotation. " - "Please use `cocoindex.DataSlice[T]` where T is the type of the value." - ) - arg_types.append(value_type_annotation) - - _transform_flow = TransformFlow(fn, arg_types) + _transform_flow = TransformFlow(fn) functools.update_wrapper(_transform_flow, fn) return _transform_flow diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 95a80ec93..d2422ba3c 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -18,13 +18,13 @@ from . import _engine # type: ignore from .convert import ( - encode_engine_value, + make_engine_value_encoder, make_engine_value_decoder, make_engine_struct_decoder, ) from .typing import ( TypeAttr, - encode_enriched_type, + encode_enriched_type_info, resolve_forward_ref, analyze_type_info, AnalyzedAnyType, @@ -185,6 +185,7 @@ class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc] _args_info: list[_ArgInfo] _kwargs_info: dict[str, _ArgInfo] _acall: Callable[..., Awaitable[Any]] + _result_encoder: Callable[[Any], Any] def __init__(self, spec: Any) -> None: super().__init__() @@ -295,15 +296,19 @@ def process_arg( base_analyze_method = getattr(self, "analyze", None) if base_analyze_method is not None: - result = base_analyze_method(*args, **kwargs) + result_type = base_analyze_method(*args, **kwargs) else: - result = expected_return + result_type = expected_return if len(attributes) > 0: - result = Annotated[result, *attributes] + result_type = Annotated[result_type, *attributes] - encoded_type = encode_enriched_type(result) + analyzed_result_type_info = analyze_type_info(result_type) + encoded_type = encode_enriched_type_info(analyzed_result_type_info) if potentially_missing_required_arg: encoded_type["nullable"] = True + + self._result_encoder = make_engine_value_encoder(analyzed_result_type_info) + return encoded_type async def prepare(self) -> None: @@ -343,7 +348,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: output = await self._acall(*decoded_args, **decoded_kwargs) else: output = await self._acall(*decoded_args, **decoded_kwargs) - return encode_engine_value(output, type_hint=expected_return) + return self._result_encoder(output) _WrappedClass.__name__ = executor_cls.__name__ _WrappedClass.__doc__ = executor_cls.__doc__ diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 77cc15990..902f18d02 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -2,7 +2,7 @@ import inspect import uuid from dataclasses import dataclass, make_dataclass -from typing import Annotated, Any, Callable, Literal, NamedTuple +from typing import Annotated, Any, Callable, Literal, NamedTuple, Type import numpy as np import pytest @@ -11,7 +11,7 @@ import cocoindex from cocoindex.convert import ( dump_engine_object, - encode_engine_value, + make_engine_value_encoder, make_engine_value_decoder, ) from cocoindex.typing import ( @@ -69,6 +69,14 @@ class CustomerNamedTuple(NamedTuple): tags: list[Tag] | None = None +def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any: + """ + Encode a Python value to an engine value. + """ + encoder = make_engine_value_encoder(analyze_type_info(type_hint)) + return encoder(value) + + def build_engine_value_decoder( engine_type_in_py: Any, python_type: Any | None = None ) -> Callable[[Any], Any]: