Skip to content

Commit 7501bc5

Browse files
committed
feat: support transform for FlowBuilder
1 parent dfacef3 commit 7501bc5

File tree

3 files changed

+82
-109
lines changed

3 files changed

+82
-109
lines changed

python/cocoindex/flow.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,25 @@
55
from __future__ import annotations
66

77
import asyncio
8-
import re
9-
import inspect
108
import datetime
119
import functools
12-
10+
import inspect
11+
import re
12+
from dataclasses import dataclass
13+
from enum import Enum
14+
from threading import Lock
1315
from typing import (
1416
Any,
1517
Callable,
18+
Generic,
19+
NamedTuple,
1620
Sequence,
1721
TypeVar,
18-
Generic,
22+
cast,
1923
get_args,
2024
get_origin,
21-
NamedTuple,
22-
cast,
2325
)
24-
from threading import Lock
25-
from enum import Enum
26-
from dataclasses import dataclass
26+
2727
from rich.text import Text
2828
from rich.tree import Tree
2929

@@ -32,8 +32,9 @@
3232
from . import op
3333
from . import setting
3434
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
35-
from .typing import encode_enriched_type
35+
from .op import FunctionSpec
3636
from .runtime import execution_context
37+
from .typing import encode_enriched_type
3738

3839

3940
class _NameBuilder:
@@ -91,6 +92,30 @@ def _spec_kind(spec: Any) -> str:
9192
return cast(str, spec.__class__.__name__)
9293

9394

95+
def _transform_helper(
96+
flow_builder_state: _FlowBuilderState,
97+
fn_spec: FunctionSpec,
98+
transform_args: list[tuple[Any, str | None]],
99+
name: str | None = None,
100+
) -> DataSlice[Any]:
101+
if not isinstance(fn_spec, FunctionSpec):
102+
raise ValueError("transform() can only be called on a CocoIndex function")
103+
104+
return _create_data_slice(
105+
flow_builder_state,
106+
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
107+
_spec_kind(fn_spec),
108+
dump_engine_object(fn_spec),
109+
transform_args,
110+
target_scope,
111+
flow_builder_state.field_name_builder.build_name(
112+
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
113+
),
114+
),
115+
name,
116+
)
117+
118+
94119
T = TypeVar("T")
95120
S = TypeVar("S")
96121

@@ -190,31 +215,19 @@ def transform(
190215
"""
191216
Apply a function to the data slice.
192217
"""
193-
if not isinstance(fn_spec, op.FunctionSpec):
194-
raise ValueError("transform() can only be called on a CocoIndex function")
195-
196-
transform_args: list[tuple[Any, str | None]]
197-
transform_args = [(self._state.engine_data_slice, None)]
218+
transform_args: list[tuple[Any, str | None]] = [
219+
(self._state.engine_data_slice, None)
220+
]
198221
transform_args += [
199222
(self._state.flow_builder_state.get_data_slice(v), None) for v in args
200223
]
201224
transform_args += [
202225
(self._state.flow_builder_state.get_data_slice(v), k)
203-
for (k, v) in kwargs.items()
226+
for k, v in kwargs.items()
204227
]
205228

206-
flow_builder_state = self._state.flow_builder_state
207-
return _create_data_slice(
208-
flow_builder_state,
209-
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
210-
_spec_kind(fn_spec),
211-
dump_engine_object(fn_spec),
212-
transform_args,
213-
target_scope,
214-
flow_builder_state.field_name_builder.build_name(
215-
name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_"
216-
),
217-
),
229+
return _transform_helper(
230+
self._state.flow_builder_state, fn_spec, transform_args
218231
)
219232

220233
def call(self, func: Callable[..., S], *args: Any, **kwargs: Any) -> S:
@@ -445,6 +458,24 @@ def add_source(
445458
name,
446459
)
447460

461+
def transform(
462+
self, fn_spec: FunctionSpec, *args: Any, **kwargs: Any
463+
) -> DataSlice[Any]:
464+
"""
465+
Apply a function to inputs, returning a DataSlice.
466+
"""
467+
transform_args: list[tuple[Any, str | None]] = [
468+
(self._state.get_data_slice(v), None) for v in args
469+
]
470+
transform_args += [
471+
(self._state.get_data_slice(v), k) for k, v in kwargs.items()
472+
]
473+
474+
if not transform_args:
475+
raise ValueError("At least one input is required for transformation")
476+
477+
return _transform_helper(self._state, fn_spec, transform_args)
478+
448479
def declare(self, spec: op.DeclarationSpec) -> None:
449480
"""
450481
Add a declaration to the flow.

python/cocoindex/functions.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,26 @@
11
"""All builtin functions."""
22

33
import dataclasses
4-
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeVar
4+
from typing import Annotated, Any, Literal
55

66
import numpy as np
77
from numpy.typing import NDArray
88

99
from . import llm, op
10-
from .flow import DataSlice
1110
from .typing import TypeAttr, Vector
1211

13-
# Libraries that are heavy to import. Lazily import them later.
14-
if TYPE_CHECKING:
15-
import sentence_transformers
12+
# Check if sentence_transformers is available
13+
try:
14+
import sentence_transformers # type: ignore
1615

17-
T = TypeVar("T")
16+
_SENTENCE_TRANSFORMERS_AVAILABLE = True
17+
except ImportError:
18+
_SENTENCE_TRANSFORMERS_AVAILABLE = False
1819

1920

2021
class ParseJson(op.FunctionSpec):
2122
"""Parse a text into a JSON object."""
2223

23-
def __call__(
24-
self, *, text: DataSlice[T], language: str | None = "json"
25-
) -> DataSlice[T]:
26-
return super().__call__(text=text, language=language)
27-
2824

2925
@dataclasses.dataclass
3026
class CustomLanguageSpec:
@@ -40,23 +36,6 @@ class SplitRecursively(op.FunctionSpec):
4036

4137
custom_languages: list[CustomLanguageSpec] = dataclasses.field(default_factory=list)
4238

43-
def __call__(
44-
self,
45-
*,
46-
text: DataSlice[T],
47-
chunk_size: int,
48-
min_chunk_size: int | None = None,
49-
chunk_overlap: int | None = None,
50-
language: DataSlice[T] | None = None,
51-
) -> DataSlice[T]:
52-
return super().__call__(
53-
text=text,
54-
chunk_size=chunk_size,
55-
language=language,
56-
min_chunk_size=min_chunk_size,
57-
chunk_overlap=chunk_overlap,
58-
)
59-
6039

6140
class EmbedText(op.FunctionSpec):
6241
"""Embed a text into a vector space."""
@@ -75,11 +54,6 @@ class ExtractByLlm(op.FunctionSpec):
7554
output_type: type
7655
instruction: str | None = None
7756

78-
def __call__(
79-
self, *, text: DataSlice[T] | None = None, image: DataSlice[T] | None = None
80-
) -> DataSlice[T]:
81-
return super().__call__(text=text, image=image)
82-
8357

8458
class SentenceTransformerEmbed(op.FunctionSpec):
8559
"""
@@ -89,6 +63,10 @@ class SentenceTransformerEmbed(op.FunctionSpec):
8963
9064
model: The name of the SentenceTransformer model to use.
9165
args: Additional arguments to pass to the SentenceTransformer constructor. e.g. {"trust_remote_code": True}
66+
67+
Note:
68+
This function requires the optional sentence-transformers dependency.
69+
Install it with: pip install 'cocoindex[embeddings]'
9270
"""
9371

9472
model: str
@@ -103,6 +81,14 @@ class SentenceTransformerEmbedExecutor:
10381
_model: "sentence_transformers.SentenceTransformer"
10482

10583
def analyze(self, text: Any) -> type:
84+
if not _SENTENCE_TRANSFORMERS_AVAILABLE:
85+
raise ImportError(
86+
"sentence_transformers is required for SentenceTransformerEmbed function. "
87+
"Install it with one of these commands:\n"
88+
" pip install 'cocoindex[embeddings]'\n"
89+
" pip install sentence-transformers"
90+
)
91+
10692
import sentence_transformers # pylint: disable=import-outside-toplevel
10793

10894
args = self.spec.args or {}

python/cocoindex/op.py

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from typing import Any, Awaitable, Callable, Protocol, dataclass_transform
1010

1111
from . import _engine # type: ignore
12-
from .convert import dump_engine_object, encode_engine_value, make_engine_value_decoder
13-
from .flow import DataSlice, _create_data_slice, _spec_kind, _to_snake_case
12+
from .convert import encode_engine_value, make_engine_value_decoder
1413
from .typing import encode_enriched_type, resolve_forward_ref
1514

1615

@@ -49,51 +48,7 @@ class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: dis
4948

5049

5150
class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint: disable=too-few-public-methods
52-
"""A function spec. Can be instantiated and called like a function: spec(args...).
53-
For non-chain-style calls, use spec(args...) with at least one DataSlice argument.
54-
For chain-style calls, use data_slice.transform(spec, args...).
55-
"""
56-
57-
def __call__(self, *args: Any, **kwargs: Any) -> DataSlice[Any]:
58-
"""Execute the function, returning a DataSlice."""
59-
60-
data_slice_args = [arg for arg in args if isinstance(arg, DataSlice)]
61-
data_slice_kwargs = {
62-
k: v for k, v in kwargs.items() if isinstance(v, DataSlice)
63-
}
64-
if not data_slice_args and not data_slice_kwargs:
65-
raise ValueError(
66-
"At least one DataSlice argument is required to provide flow context"
67-
)
68-
69-
first_data_slice = (
70-
data_slice_args[0]
71-
if data_slice_args
72-
else list(data_slice_kwargs.values())[0]
73-
)
74-
flow_builder_state = first_data_slice._state.flow_builder_state
75-
76-
transform_args: list[tuple[Any, str | None]] = [
77-
(flow_builder_state.get_data_slice(v), None) for v in args if v is not None
78-
]
79-
transform_args += [
80-
(flow_builder_state.get_data_slice(v), k)
81-
for k, v in kwargs.items()
82-
if v is not None
83-
]
84-
85-
return _create_data_slice(
86-
flow_builder_state,
87-
lambda target_scope, name: flow_builder_state.engine_flow_builder.transform(
88-
_spec_kind(self),
89-
dump_engine_object(self),
90-
transform_args,
91-
target_scope,
92-
flow_builder_state.field_name_builder.build_name(
93-
name, prefix=_to_snake_case(_spec_kind(self)) + "_"
94-
),
95-
),
96-
)
51+
"""A function spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)"""
9752

9853

9954
class TargetSpec(metaclass=SpecMeta, category=OpCategory.TARGET): # pylint: disable=too-few-public-methods
@@ -355,7 +310,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
355310
return fn(*args, **kwargs)
356311

357312
class _Spec(FunctionSpec):
358-
pass
313+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
314+
return fn(*args, **kwargs)
359315

360316
_Spec.__name__ = op_name
361317
_Spec.__doc__ = fn.__doc__

0 commit comments

Comments
 (0)