Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 12 additions & 39 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,21 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)


def _make_encoder_closure(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
def make_engine_value_encoder(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
"""
Create an encoder closure for a specific type.
"""
variant = type_info.variant

if isinstance(variant, AnalyzedUnknownType):
raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")

if isinstance(variant, AnalyzedListType):
elem_type_info = (
analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO
)
if isinstance(elem_type_info.variant, AnalyzedStructType):
elem_encoder = _make_encoder_closure(elem_type_info)
elem_encoder = make_engine_value_encoder(elem_type_info)

def encode_struct_list(value: Any) -> Any:
return None if value is None else [elem_encoder(v) for v in value]
Expand All @@ -104,11 +107,13 @@ def encode_struct_dict(value: Any) -> Any:
# Handle KTable case
if value and is_struct_type(val_type):
key_encoder = (
_make_encoder_closure(analyze_type_info(key_type))
make_engine_value_encoder(analyze_type_info(key_type))
if is_struct_type(key_type)
else _make_encoder_closure(ANY_TYPE_INFO)
else make_engine_value_encoder(ANY_TYPE_INFO)
)
value_encoder = make_engine_value_encoder(
analyze_type_info(val_type)
)
value_encoder = _make_encoder_closure(analyze_type_info(val_type))
return [
[key_encoder(k)] + value_encoder(v) for k, v in value.items()
]
Expand All @@ -122,7 +127,7 @@ def encode_struct_dict(value: Any) -> Any:
if dataclasses.is_dataclass(struct_type):
fields = dataclasses.fields(struct_type)
field_encoders = [
_make_encoder_closure(analyze_type_info(f.type)) for f in fields
make_engine_value_encoder(analyze_type_info(f.type)) for f in fields
]
field_names = [f.name for f in fields]

Expand All @@ -140,7 +145,7 @@ def encode_dataclass(value: Any) -> Any:
annotations = struct_type.__annotations__
field_names = list(getattr(struct_type, "_fields", ()))
field_encoders = [
_make_encoder_closure(
make_engine_value_encoder(
analyze_type_info(annotations[name])
if name in annotations
else ANY_TYPE_INFO
Expand Down Expand Up @@ -170,38 +175,6 @@ def encode_basic_value(value: Any) -> Any:
return encode_basic_value


def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]:
    """
    Build a callable that converts Python values into engine values.

    Args:
        type_hint: Type annotation describing the values to be encoded.

    Returns:
        A closure taking one Python value and returning its engine encoding.

    Raises:
        ValueError: If analyzing `type_hint` yields an `AnalyzedUnknownType`.
    """
    analyzed = analyze_type_info(type_hint)
    # Only annotations with a known variant get an encoder; anything else is
    # rejected here, before a closure is built.
    if not isinstance(analyzed.variant, AnalyzedUnknownType):
        return _make_encoder_closure(analyzed)
    raise ValueError(f"Type annotation `{analyzed.core_type}` is unsupported")


def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
    """
    Convert a single Python value into its engine representation.

    Args:
        value: The Python value to encode.
        type_hint: Type annotation for the value. This should always be provided.

    Returns:
        The encoded engine value.
    """
    # Build a one-off encoder for the annotation and apply it immediately.
    return make_engine_value_encoder(type_hint)(value)


def make_engine_value_decoder(
field_path: list[str],
src_type: dict[str, Any],
Expand Down
104 changes: 47 additions & 57 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from . import index
from . import op
from . import setting
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
from .convert import (
dump_engine_object,
make_engine_value_decoder,
make_engine_value_encoder,
)
from .op import FunctionSpec
from .runtime import execution_context
from .setup import SetupChangeBundle
Expand Down Expand Up @@ -974,33 +978,60 @@ class TransformFlowInfo(NamedTuple):
result_decoder: Callable[[Any], T]


class FlowArgInfo(NamedTuple):
    """Per-parameter record of a flow function: name, value type and encoder."""

    # Name of the flow function parameter.
    name: str
    # Value type associated with the parameter.
    type_hint: Any
    # Callable applied to a Python argument value to produce the encoded value.
    encoder: Callable[[Any], Any]


class TransformFlow(Generic[T]):
"""
A transient transformation flow that transforms in-memory data.
"""

_flow_fn: Callable[..., DataSlice[T]]
_flow_name: str
_flow_arg_types: list[Any]
_param_names: list[str]
_args_info: list[FlowArgInfo]

_lazy_lock: asyncio.Lock
_lazy_flow_info: TransformFlowInfo | None = None

def __init__(
    self,
    flow_fn: Callable[..., DataSlice[T]],
    /,
    name: str | None = None,
):
    """
    Set up a transform flow around `flow_fn`.

    Every parameter of `flow_fn` is inspected once here: its value type is
    extracted from the parameter annotation and an engine value encoder is
    built for it, so evaluation can reuse the encoder instead of rebuilding
    it on each call.

    Args:
        flow_fn: The function that defines the flow.
        name: Optional name handed to the flow name builder.

    Raises:
        ValueError: If a parameter cannot be passed by name, or if no value
            type can be obtained from its annotation.
    """
    self._flow_fn = flow_fn
    self._flow_name = _transform_flow_name_builder.build_name(
        name, prefix="_transform_flow_"
    )
    self._lazy_lock = asyncio.Lock()

    # Parameter kinds that can be supplied by name.
    by_name_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    args_info: list[FlowArgInfo] = []
    for param_name, param in inspect.signature(flow_fn).parameters.items():
        if param.kind not in by_name_kinds:
            raise ValueError(
                f"Parameter `{param_name}` is not a parameter that can be passed by name"
            )
        value_type_annotation: type | None = _get_data_slice_annotation_type(
            param.annotation
        )
        if value_type_annotation is None:
            raise ValueError(
                f"Parameter `{param_name}` for {flow_fn} has no value type annotation. "
                "Please use `cocoindex.DataSlice[T]` where T is the type of the value."
            )
        encoder = make_engine_value_encoder(
            analyze_type_info(value_type_annotation)
        )
        args_info.append(FlowArgInfo(param_name, value_type_annotation, encoder))
    self._args_info = args_info

def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[T]:
    """Call the wrapped flow function directly, forwarding all arguments."""
    flow_fn = self._flow_fn
    return flow_fn(*args, **kwargs)

Expand All @@ -1020,31 +1051,15 @@ async def _flow_info_async(self) -> TransformFlowInfo:

async def _build_flow_info_async(self) -> TransformFlowInfo:
flow_builder_state = _FlowBuilderState(self._flow_name)
sig = inspect.signature(self._flow_fn)
if len(sig.parameters) != len(self._flow_arg_types):
raise ValueError(
f"Number of parameters in the flow function ({len(sig.parameters)}) "
f"does not match the number of argument types ({len(self._flow_arg_types)})"
)

kwargs: dict[str, DataSlice[T]] = {}
for (param_name, param), param_type in zip(
sig.parameters.items(), self._flow_arg_types
):
if param.kind not in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
):
raise ValueError(
f"Parameter `{param_name}` is not a parameter can be passed by name"
)
encoded_type = encode_enriched_type(param_type)
for arg_info in self._args_info:
encoded_type = encode_enriched_type(arg_info.type_hint)
if encoded_type is None:
raise ValueError(f"Parameter `{param_name}` has no type annotation")
raise ValueError(f"Parameter `{arg_info.name}` has no type annotation")
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
param_name, encoded_type
arg_info.name, encoded_type
)
kwargs[param_name] = DataSlice(
kwargs[arg_info.name] = DataSlice(
_DataSliceState(flow_builder_state, engine_ds)
)

Expand All @@ -1057,13 +1072,12 @@ async def _build_flow_info_async(self) -> TransformFlowInfo:
execution_context.event_loop
)
)
self._param_names = list(sig.parameters.keys())

engine_return_type = (
_data_slice_state(output).engine_data_slice.data_type().schema()
)
python_return_type: type[T] | None = _get_data_slice_annotation_type(
sig.return_annotation
inspect.signature(self._flow_fn).return_annotation
)
result_decoder = make_engine_value_decoder(
[], engine_return_type["type"], analyze_type_info(python_return_type)
async def eval_async(self, *args: Any, **kwargs: Any) -> T:
    """
    Evaluate the transform flow on the given arguments.

    Each entry of `self._args_info` is resolved in order: the positional
    argument at the same index is used when present, otherwise the keyword
    argument whose key equals the parameter name. Each resolved value is
    passed through that parameter's encoder, the encoded list is handed to
    the engine flow, and the engine result is decoded before returning.

    Args:
        *args: Positional values for the leading flow parameters.
        **kwargs: Values for flow parameters, keyed by parameter name.

    Returns:
        The decoded result of the engine evaluation.

    Raises:
        ValueError: If a flow parameter is supplied neither positionally
            nor by keyword.
    """
    flow_info = await self._flow_info_async()
    params = []
    for i, arg_info in enumerate(self._args_info):
        if i < len(args):
            value = args[i]
        elif arg_info.name in kwargs:
            # Look the keyword up by the parameter's own name; a separate
            # local for the value keeps the name and the value apart.
            value = kwargs[arg_info.name]
        else:
            raise ValueError(f"Parameter {arg_info.name} is not provided")
        params.append(arg_info.encoder(value))
    engine_result = await flow_info.engine_flow.evaluate_async(params)
    return flow_info.result_decoder(engine_result)

Expand All @@ -1117,27 +1127,7 @@ def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T]
"""

def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]) -> TransformFlow[T]:
    """Wrap `fn` in a `TransformFlow` and copy `fn`'s metadata onto it."""
    wrapped_flow = TransformFlow(fn)
    # Copy the function's metadata (name, docstring, ...) onto the flow object.
    functools.update_wrapper(wrapped_flow, fn)
    return wrapped_flow

Expand Down
19 changes: 12 additions & 7 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

from . import _engine # type: ignore
from .convert import (
encode_engine_value,
make_engine_value_encoder,
make_engine_value_decoder,
make_engine_struct_decoder,
)
from .typing import (
TypeAttr,
encode_enriched_type,
encode_enriched_type_info,
resolve_forward_ref,
analyze_type_info,
AnalyzedAnyType,
Expand Down Expand Up @@ -185,6 +185,7 @@ class _WrappedClass(executor_cls, _Fallback): # type: ignore[misc]
_args_info: list[_ArgInfo]
_kwargs_info: dict[str, _ArgInfo]
_acall: Callable[..., Awaitable[Any]]
_result_encoder: Callable[[Any], Any]

def __init__(self, spec: Any) -> None:
super().__init__()
Expand Down Expand Up @@ -295,15 +296,19 @@ def process_arg(

base_analyze_method = getattr(self, "analyze", None)
if base_analyze_method is not None:
result = base_analyze_method(*args, **kwargs)
result_type = base_analyze_method(*args, **kwargs)
else:
result = expected_return
result_type = expected_return
if len(attributes) > 0:
result = Annotated[result, *attributes]
result_type = Annotated[result_type, *attributes]

encoded_type = encode_enriched_type(result)
analyzed_result_type_info = analyze_type_info(result_type)
encoded_type = encode_enriched_type_info(analyzed_result_type_info)
if potentially_missing_required_arg:
encoded_type["nullable"] = True

self._result_encoder = make_engine_value_encoder(analyzed_result_type_info)

return encoded_type

async def prepare(self) -> None:
Expand Down Expand Up @@ -343,7 +348,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
output = await self._acall(*decoded_args, **decoded_kwargs)
else:
output = await self._acall(*decoded_args, **decoded_kwargs)
return encode_engine_value(output, type_hint=expected_return)
return self._result_encoder(output)

_WrappedClass.__name__ = executor_cls.__name__
_WrappedClass.__doc__ = executor_cls.__doc__
Expand Down
12 changes: 10 additions & 2 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import uuid
from dataclasses import dataclass, make_dataclass
from typing import Annotated, Any, Callable, Literal, NamedTuple
from typing import Annotated, Any, Callable, Literal, NamedTuple, Type

import numpy as np
import pytest
Expand All @@ -11,7 +11,7 @@
import cocoindex
from cocoindex.convert import (
dump_engine_object,
encode_engine_value,
make_engine_value_encoder,
make_engine_value_decoder,
)
from cocoindex.typing import (
Expand Down Expand Up @@ -69,6 +69,14 @@ class CustomerNamedTuple(NamedTuple):
tags: list[Tag] | None = None


def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
    """
    Encode a Python value to an engine value.

    Test helper: analyzes `type_hint`, builds an encoder from the analyzed
    type info and applies it to `value`.
    """
    type_info = analyze_type_info(type_hint)
    return make_engine_value_encoder(type_info)(value)


def build_engine_value_decoder(
engine_type_in_py: Any, python_type: Any | None = None
) -> Callable[[Any], Any]:
Expand Down
Loading