Skip to content

Commit 000fae8

Browse files
authored
feat: add transform method to FlowBuilder (#675)
* feat: support function calling in non-chain style
* feat: support transform for FlowBuilder
1 parent 0294f46 commit 000fae8

File tree

2 files changed: 63 additions (+63) and 31 deletions (-31).

python/cocoindex/flow.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,27 @@
55
from __future__ import annotations
66

77
import asyncio
8-
import re
9-
import inspect
108
import datetime
119
import functools
10+
import inspect
11+
import re
12+
13+
from dataclasses import dataclass
14+
from enum import Enum
15+
from threading import Lock
1216
from typing import (
1317
Any,
1418
Callable,
19+
Generic,
20+
NamedTuple,
1521
Sequence,
1622
TypeVar,
17-
Generic,
23+
cast,
1824
get_args,
1925
get_origin,
20-
NamedTuple,
21-
cast,
2226
Iterable,
2327
)
24-
from threading import Lock
25-
from enum import Enum
26-
from dataclasses import dataclass
28+
2729
from rich.text import Text
2830
from rich.tree import Tree
2931

@@ -32,9 +34,10 @@
3234
from . import op
3335
from . import setting
3436
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
35-
from .typing import encode_enriched_type
37+
from .op import FunctionSpec
3638
from .runtime import execution_context
3739
from .setup import SetupChangeBundle
40+
from .typing import encode_enriched_type
3841

3942

4043
class _NameBuilder:
@@ -92,6 +95,30 @@ def _spec_kind(spec: Any) -> str:
9295
return cast(str, spec.__class__.__name__)
9396

9497

98+
def _transform_helper(
99+
flow_builder_state: _FlowBuilderState,
100+
fn_spec: FunctionSpec,
101+
transform_args: list[tuple[Any, str | None]],
102+
name: str | None = None,
103+
) -> DataSlice[Any]:
104+
if not isinstance(fn_spec, FunctionSpec):
105+
raise ValueError("transform() can only be called on a CocoIndex function")
106+
107+
return _create_data_slice(
108+
flow_builder_state,
109+
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
110+
_spec_kind(fn_spec),
111+
dump_engine_object(fn_spec),
112+
transform_args,
113+
target_scope,
114+
flow_builder_state.field_name_builder.build_name(
115+
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
116+
),
117+
),
118+
name,
119+
)
120+
121+
95122
T = TypeVar("T")
96123
S = TypeVar("S")
97124

@@ -191,31 +218,19 @@ def transform(
191218
"""
192219
Apply a function to the data slice.
193220
"""
194-
if not isinstance(fn_spec, op.FunctionSpec):
195-
raise ValueError("transform() can only be called on a CocoIndex function")
196-
197-
transform_args: list[tuple[Any, str | None]]
198-
transform_args = [(self._state.engine_data_slice, None)]
221+
transform_args: list[tuple[Any, str | None]] = [
222+
(self._state.engine_data_slice, None)
223+
]
199224
transform_args += [
200225
(self._state.flow_builder_state.get_data_slice(v), None) for v in args
201226
]
202227
transform_args += [
203228
(self._state.flow_builder_state.get_data_slice(v), k)
204-
for (k, v) in kwargs.items()
229+
for k, v in kwargs.items()
205230
]
206231

207-
flow_builder_state = self._state.flow_builder_state
208-
return _create_data_slice(
209-
flow_builder_state,
210-
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
211-
_spec_kind(fn_spec),
212-
dump_engine_object(fn_spec),
213-
transform_args,
214-
target_scope,
215-
flow_builder_state.field_name_builder.build_name(
216-
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
217-
),
218-
),
232+
return _transform_helper(
233+
self._state.flow_builder_state, fn_spec, transform_args
219234
)
220235

221236
def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S:
@@ -446,6 +461,24 @@ def add_source(
446461
name,
447462
)
448463

464+
def transform(
465+
self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any
466+
) -> DataSlice[Any]:
467+
"""
468+
Apply a function to inputs, returning a DataSlice.
469+
"""
470+
transform_args: list[tuple[Any, str | None]] = [
471+
(self._state.get_data_slice(v), None) for v in args
472+
]
473+
transform_args += [
474+
(self._state.get_data_slice(v), k) for k, v in kwargs.items()
475+
]
476+
477+
if not transform_args:
478+
raise ValueError("At least one input is required for transformation")
479+
480+
return _transform_helper(self._state, fn_spec, transform_args)
481+
449482
def declare(self, spec: op.DeclarationSpec) -> None:
450483
"""
451484
Add a declaration to the flow.

python/cocoindex/op.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import asyncio
66
import dataclasses
77
import inspect
8-
9-
from typing import Protocol, Any, Callable, Awaitable, dataclass_transform
108
from enum import Enum
9+
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform
1110

12-
from .typing import encode_enriched_type, resolve_forward_ref
13-
from .convert import encode_engine_value, make_engine_value_decoder
1411
from . import _engine # type: ignore
12+
from .convert import encode_engine_value, make_engine_value_decoder
13+
from .typing import encode_enriched_type, resolve_forward_ref
1514

1615

1716
class OpCategory(Enum):

0 commit comments

Comments
 (0)