|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import asyncio |
8 | | -import re |
9 | | -import inspect |
10 | 8 | import datetime |
11 | 9 | import functools |
12 | | - |
| 10 | +import inspect |
| 11 | +import re |
| 12 | +from dataclasses import dataclass |
| 13 | +from enum import Enum |
| 14 | +from threading import Lock |
13 | 15 | from typing import ( |
14 | 16 | Any, |
15 | 17 | Callable, |
| 18 | + Generic, |
| 19 | + NamedTuple, |
16 | 20 | Sequence, |
17 | 21 | TypeVar, |
18 | | - Generic, |
| 22 | + cast, |
19 | 23 | get_args, |
20 | 24 | get_origin, |
21 | | - NamedTuple, |
22 | | - cast, |
23 | 25 | ) |
24 | | -from threading import Lock |
25 | | -from enum import Enum |
26 | | -from dataclasses import dataclass |
| 26 | + |
27 | 27 | from rich.text import Text |
28 | 28 | from rich.tree import Tree |
29 | 29 |
|
|
32 | 32 | from . import op |
33 | 33 | from . import setting |
34 | 34 | from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder |
35 | | -from .typing import encode_enriched_type |
| 35 | +from .op import FunctionSpec |
36 | 36 | from .runtime import execution_context |
| 37 | +from .typing import encode_enriched_type |
37 | 38 |
|
38 | 39 |
|
39 | 40 | class _NameBuilder: |
@@ -91,6 +92,30 @@ def _spec_kind(spec: Any) -> str: |
91 | 92 | return cast(str, spec.__class__.__name__) |
92 | 93 |
|
93 | 94 |
|
| 95 | +def _transform_helper( |
| 96 | + flow_builder_state: _FlowBuilderState, |
| 97 | + fn_spec: FunctionSpec, |
| 98 | + transform_args: list[tuple[Any, str | None]], |
| 99 | + name: str | None = None, |
| 100 | +) -> DataSlice[Any]: |
| 101 | + if not isinstance(fn_spec, FunctionSpec): |
| 102 | + raise ValueError("transform() can only be called on a CocoIndex function") |
| 103 | + |
| 104 | + return _create_data_slice( |
| 105 | + flow_builder_state, |
| 106 | + lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( |
| 107 | + _spec_kind(fn_spec), |
| 108 | + dump_engine_object(fn_spec), |
| 109 | + transform_args, |
| 110 | + target_scope, |
| 111 | + flow_builder_state.field_name_builder.build_name( |
| 112 | + name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" |
| 113 | + ), |
| 114 | + ), |
| 115 | + name, |
| 116 | + ) |
| 117 | + |
| 118 | + |
94 | 119 | T = TypeVar("T") |
95 | 120 | S = TypeVar("S") |
96 | 121 |
|
@@ -190,31 +215,19 @@ def transform( |
190 | 215 | """ |
191 | 216 | Apply a function to the data slice. |
192 | 217 | """ |
193 | | - if not isinstance(fn_spec, op.FunctionSpec): |
194 | | - raise ValueError("transform() can only be called on a CocoIndex function") |
195 | | - |
196 | | - transform_args: list[tuple[Any, str | None]] |
197 | | - transform_args = [(self._state.engine_data_slice, None)] |
| 218 | + transform_args: list[tuple[Any, str | None]] = [ |
| 219 | + (self._state.engine_data_slice, None) |
| 220 | + ] |
198 | 221 | transform_args += [ |
199 | 222 | (self._state.flow_builder_state.get_data_slice(v), None) for v in args |
200 | 223 | ] |
201 | 224 | transform_args += [ |
202 | 225 | (self._state.flow_builder_state.get_data_slice(v), k) |
203 | | - for (k, v) in kwargs.items() |
| 226 | + for k, v in kwargs.items() |
204 | 227 | ] |
205 | 228 |
|
206 | | - flow_builder_state = self._state.flow_builder_state |
207 | | - return _create_data_slice( |
208 | | - flow_builder_state, |
209 | | - lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( |
210 | | - _spec_kind(fn_spec), |
211 | | - dump_engine_object(fn_spec), |
212 | | - transform_args, |
213 | | - target_scope, |
214 | | - flow_builder_state.field_name_builder.build_name( |
215 | | - name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" |
216 | | - ), |
217 | | - ), |
| 229 | + return _transform_helper( |
| 230 | + self._state.flow_builder_state, fn_spec, transform_args |
218 | 231 | ) |
219 | 232 |
|
220 | 233 | def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S: |
@@ -445,6 +458,24 @@ def add_source( |
445 | 458 | name, |
446 | 459 | ) |
447 | 460 |
|
| 461 | + def transform( |
| 462 | + self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any |
| 463 | + ) -> DataSlice[Any]: |
| 464 | + """ |
| 465 | + Apply a function to inputs, returning a DataSlice. |
| 466 | + """ |
| 467 | + transform_args: list[tuple[Any, str | None]] = [ |
| 468 | + (self._state.get_data_slice(v), None) for v in args |
| 469 | + ] |
| 470 | + transform_args += [ |
| 471 | + (self._state.get_data_slice(v), k) for k, v in kwargs.items() |
| 472 | + ] |
| 473 | + |
| 474 | + if not transform_args: |
| 475 | + raise ValueError("At least one input is required for transformation") |
| 476 | + |
| 477 | + return _transform_helper(self._state, fn_spec, transform_args) |
| 478 | + |
448 | 479 | def declare(self, spec: op.DeclarationSpec) -> None: |
449 | 480 | """ |
450 | 481 | Add a declaration to the flow. |
|
0 commit comments