Skip to content

Commit 37e0670

Browse files
authored
test: add a test to cover batching custom function support (#1239)
1 parent 7904bce commit 37e0670

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

python/cocoindex/tests/test_transform_flow.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,66 @@ def transform_flow_with_analyze_prepare(
205205
result = transform_flow_with_analyze_prepare.eval("Hello")
206206
expected = "Hello world!!"
207207
assert result == expected, f"Expected {expected}, got {result}"
208+
209+
210+
# Test batching behavior.
211+
212+
213+
@cocoindex.op.function(batching=True)
214+
def batching_append_world(text: list[str]) -> list[str]:
215+
"""Append ' world' to the input text."""
216+
return [f"{t} world" for t in text]
217+
218+
219+
class batchingAppendSuffix(cocoindex.op.FunctionSpec):
220+
suffix: str
221+
222+
223+
@cocoindex.op.executor_class(batching=True)
224+
class batchingAppendSuffixExecutor:
225+
spec: batchingAppendSuffix
226+
227+
def __call__(self, text: list[str]) -> list[str]:
228+
return [f"{t}{self.spec.suffix}" for t in text]
229+
230+
231+
class batchingAppendSuffixWithAnalyzePrepare(cocoindex.op.FunctionSpec):
232+
suffix: str
233+
234+
235+
@cocoindex.op.executor_class(batching=True)
236+
class batchingAppendSuffixWithAnalyzePrepareExecutor:
237+
spec: batchingAppendSuffixWithAnalyzePrepare
238+
suffix: str
239+
240+
def analyze(self) -> Any:
241+
return str
242+
243+
def prepare(self) -> None:
244+
self.suffix = self.spec.suffix
245+
246+
def __call__(self, text: list[str]) -> list[str]:
247+
return [f"{t}{self.suffix}" for t in text]
248+
249+
250+
def test_batching_function() -> None:
251+
@cocoindex.transform_flow()
252+
def transform_flow(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
253+
return text.transform(batching_append_world).transform(
254+
batchingAppendSuffix(suffix="!")
255+
)
256+
257+
result = transform_flow.eval("Hello")
258+
expected = "Hello world!"
259+
assert result == expected, f"Expected {expected}, got {result}"
260+
261+
@cocoindex.transform_flow()
262+
def transform_flow_with_analyze_prepare(
263+
text: cocoindex.DataSlice[str],
264+
) -> cocoindex.DataSlice[str]:
265+
return text.transform(batching_append_world).transform(
266+
batchingAppendSuffixWithAnalyzePrepare(suffix="!!")
267+
)
268+
269+
result = transform_flow_with_analyze_prepare.eval("Hello")
270+
expected = "Hello world!!"

0 commit comments

Comments
 (0)