diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 5d1de9a1..605e771d 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -5,25 +5,27 @@ from __future__ import annotations import asyncio -import re -import inspect import datetime import functools +import inspect +import re + +from dataclasses import dataclass +from enum import Enum +from threading import Lock from typing import ( Any, Callable, + Generic, + NamedTuple, Sequence, TypeVar, - Generic, + cast, get_args, get_origin, - NamedTuple, - cast, Iterable, ) -from threading import Lock -from enum import Enum -from dataclasses import dataclass + from rich.text import Text from rich.tree import Tree @@ -32,9 +34,10 @@ from . import op from . import setting from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder -from .typing import encode_enriched_type +from .op import FunctionSpec from .runtime import execution_context from .setup import SetupChangeBundle +from .typing import encode_enriched_type class _NameBuilder: @@ -92,6 +95,30 @@ def _spec_kind(spec: Any) -> str: return cast(str, spec.__class__.__name__) +def _transform_helper( + flow_builder_state: _FlowBuilderState, + fn_spec: FunctionSpec, + transform_args: list[tuple[Any, str | None]], + name: str | None = None, +) -> DataSlice[Any]: + if not isinstance(fn_spec, FunctionSpec): + raise ValueError("transform() can only be called on a CocoIndex function") + + return _create_data_slice( + flow_builder_state, + lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( + _spec_kind(fn_spec), + dump_engine_object(fn_spec), + transform_args, + target_scope, + flow_builder_state.field_name_builder.build_name( + name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" + ), + ), + name, + ) + + T = TypeVar("T") S = TypeVar("S") @@ -191,31 +218,19 @@ def transform( """ Apply a function to the data slice. """ - if not isinstance(fn_spec, op.FunctionSpec): - raise ValueError("transform() can only be called on a CocoIndex function") - - transform_args: list[tuple[Any, str | None]] - transform_args = [(self._state.engine_data_slice, None)] + transform_args: list[tuple[Any, str | None]] = [ + (self._state.engine_data_slice, None) + ] transform_args += [ (self._state.flow_builder_state.get_data_slice(v), None) for v in args ] transform_args += [ (self._state.flow_builder_state.get_data_slice(v), k) - for (k, v) in kwargs.items() + for k, v in kwargs.items() ] - flow_builder_state = self._state.flow_builder_state - return _create_data_slice( - flow_builder_state, - lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( - _spec_kind(fn_spec), - dump_engine_object(fn_spec), - transform_args, - target_scope, - flow_builder_state.field_name_builder.build_name( - name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" - ), - ), + return _transform_helper( + self._state.flow_builder_state, fn_spec, transform_args ) def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S: @@ -446,6 +461,24 @@ def add_source( name, ) + def transform( + self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any + ) -> DataSlice[Any]: + """ + Apply a function to inputs, returning a DataSlice. + """ + transform_args: list[tuple[Any, str | None]] = [ + (self._state.get_data_slice(v), None) for v in args + ] + transform_args += [ + (self._state.get_data_slice(v), k) for k, v in kwargs.items() + ] + + if not transform_args: + raise ValueError("At least one input is required for transformation") + + return _transform_helper(self._state, fn_spec, transform_args) + def declare(self, spec: op.DeclarationSpec) -> None: """ Add a declaration to the flow. diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 140d6fd8..a53e337d 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -5,13 +5,12 @@ import asyncio import dataclasses import inspect - -from typing import Protocol, Any, Callable, Awaitable, dataclass_transform from enum import Enum +from typing import Any, Awaitable, Callable, Protocol, dataclass_transform -from .typing import encode_enriched_type, resolve_forward_ref -from .convert import encode_engine_value, make_engine_value_decoder from . import _engine # type: ignore +from .convert import encode_engine_value, make_engine_value_decoder +from .typing import encode_enriched_type, resolve_forward_ref class OpCategory(Enum):