|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import asyncio |
8 | | -import re |
9 | | -import inspect |
10 | 8 | import datetime |
11 | 9 | import functools |
| 10 | +import inspect |
| 11 | +import re |
| 12 | + |
| 13 | +from dataclasses import dataclass |
| 14 | +from enum import Enum |
| 15 | +from threading import Lock |
12 | 16 | from typing import ( |
13 | 17 | Any, |
14 | 18 | Callable, |
| 19 | + Generic, |
| 20 | + NamedTuple, |
15 | 21 | Sequence, |
16 | 22 | TypeVar, |
17 | | - Generic, |
| 23 | + cast, |
18 | 24 | get_args, |
19 | 25 | get_origin, |
20 | | - NamedTuple, |
21 | | - cast, |
22 | 26 | Iterable, |
23 | 27 | ) |
24 | | -from threading import Lock |
25 | | -from enum import Enum |
26 | | -from dataclasses import dataclass |
| 28 | + |
27 | 29 | from rich.text import Text |
28 | 30 | from rich.tree import Tree |
29 | 31 |
|
|
32 | 34 | from . import op |
33 | 35 | from . import setting |
34 | 36 | from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder |
35 | | -from .typing import encode_enriched_type |
| 37 | +from .op import FunctionSpec |
36 | 38 | from .runtime import execution_context |
37 | 39 | from .setup import SetupChangeBundle |
| 40 | +from .typing import encode_enriched_type |
38 | 41 |
|
39 | 42 |
|
40 | 43 | class _NameBuilder: |
@@ -92,6 +95,30 @@ def _spec_kind(spec: Any) -> str: |
92 | 95 | return cast(str, spec.__class__.__name__) |
93 | 96 |
|
94 | 97 |
|
| 98 | +def _transform_helper( |
| 99 | + flow_builder_state: _FlowBuilderState, |
| 100 | + fn_spec: FunctionSpec, |
| 101 | + transform_args: list[tuple[Any, str | None]], |
| 102 | + name: str | None = None, |
| 103 | +) -> DataSlice[Any]: |
| 104 | + if not isinstance(fn_spec, FunctionSpec): |
| 105 | + raise ValueError("transform() can only be called on a CocoIndex function") |
| 106 | + |
| 107 | + return _create_data_slice( |
| 108 | + flow_builder_state, |
| 109 | + lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( |
| 110 | + _spec_kind(fn_spec), |
| 111 | + dump_engine_object(fn_spec), |
| 112 | + transform_args, |
| 113 | + target_scope, |
| 114 | + flow_builder_state.field_name_builder.build_name( |
| 115 | + name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" |
| 116 | + ), |
| 117 | + ), |
| 118 | + name, |
| 119 | + ) |
| 120 | + |
| 121 | + |
95 | 122 | T = TypeVar("T") |
96 | 123 | S = TypeVar("S") |
97 | 124 |
|
@@ -191,31 +218,19 @@ def transform( |
191 | 218 | """ |
192 | 219 | Apply a function to the data slice. |
193 | 220 | """ |
194 | | - if not isinstance(fn_spec, op.FunctionSpec): |
195 | | - raise ValueError("transform() can only be called on a CocoIndex function") |
196 | | - |
197 | | - transform_args: list[tuple[Any, str | None]] |
198 | | - transform_args = [(self._state.engine_data_slice, None)] |
| 221 | + transform_args: list[tuple[Any, str | None]] = [ |
| 222 | + (self._state.engine_data_slice, None) |
| 223 | + ] |
199 | 224 | transform_args += [ |
200 | 225 | (self._state.flow_builder_state.get_data_slice(v), None) for v in args |
201 | 226 | ] |
202 | 227 | transform_args += [ |
203 | 228 | (self._state.flow_builder_state.get_data_slice(v), k) |
204 | | - for (k, v) in kwargs.items() |
| 229 | + for k, v in kwargs.items() |
205 | 230 | ] |
206 | 231 |
|
207 | | - flow_builder_state = self._state.flow_builder_state |
208 | | - return _create_data_slice( |
209 | | - flow_builder_state, |
210 | | - lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( |
211 | | - _spec_kind(fn_spec), |
212 | | - dump_engine_object(fn_spec), |
213 | | - transform_args, |
214 | | - target_scope, |
215 | | - flow_builder_state.field_name_builder.build_name( |
216 | | - name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" |
217 | | - ), |
218 | | - ), |
| 232 | + return _transform_helper( |
| 233 | + self._state.flow_builder_state, fn_spec, transform_args |
219 | 234 | ) |
220 | 235 |
|
221 | 236 | def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S: |
@@ -446,6 +461,24 @@ def add_source( |
446 | 461 | name, |
447 | 462 | ) |
448 | 463 |
|
| 464 | + def transform( |
| 465 | + self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any |
| 466 | + ) -> DataSlice[Any]: |
| 467 | + """ |
| 468 | + Apply a function to inputs, returning a DataSlice. |
| 469 | + """ |
| 470 | + transform_args: list[tuple[Any, str | None]] = [ |
| 471 | + (self._state.get_data_slice(v), None) for v in args |
| 472 | + ] |
| 473 | + transform_args += [ |
| 474 | + (self._state.get_data_slice(v), k) for k, v in kwargs.items() |
| 475 | + ] |
| 476 | + |
| 477 | + if not transform_args: |
| 478 | + raise ValueError("At least one input is required for transformation") |
| 479 | + |
| 480 | + return _transform_helper(self._state, fn_spec, transform_args) |
| 481 | + |
449 | 482 | def declare(self, spec: op.DeclarationSpec) -> None: |
450 | 483 | """ |
451 | 484 | Add a declaration to the flow. |
|
0 commit comments