Skip to content

Commit 409cd8c

Browse files
committed
test: update to correctly process child rows and return expected structure
1 parent 2bbb30d commit 409cd8c

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

python/cocoindex/tests/test_transform_flow.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing
22
from dataclasses import dataclass
3+
from typing import Any
34

45
import pytest
56

@@ -36,15 +37,15 @@ def simple_transform(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]
3637

3738

3839
@cocoindex.op.function()
39-
def extract_value(child: int) -> int:
40-
"""Extracts the value from a Child object."""
41-
return child
40+
def extract_value(value: int) -> int:
41+
"""Extracts the value."""
42+
return value
4243

4344

4445
@cocoindex.transform_flow()
4546
def for_each_transform(
4647
data: cocoindex.DataSlice[Parent],
47-
) -> cocoindex.DataSlice[Parent]:
48+
) -> cocoindex.DataSlice[Any]:
4849
"""Transform flow that processes child rows to extract values."""
4950
with data["children"].row() as child:
5051
child["new_field"] = child["value"].transform(extract_value)
@@ -73,18 +74,30 @@ def test_for_each_transform_flow() -> None:
7374
"""Test the complex transform flow with child rows."""
7475
input_data = Parent(children=[Child(1), Child(2), Child(3)])
7576
result = for_each_transform.eval(input_data)
76-
expected = Parent(children=[Child(1), Child(2), Child(3)])
77+
expected = {
78+
"children": [
79+
{"value": 1, "new_field": 1},
80+
{"value": 2, "new_field": 2},
81+
{"value": 3, "new_field": 3},
82+
]
83+
}
7784
assert result == expected, f"Expected {expected}, got {result}"
7885

7986
input_data = Parent(children=[])
8087
result = for_each_transform.eval(input_data)
81-
assert result == Parent(children=[]), f"Expected [], got {result}"
88+
assert result == {"children": []}, f"Expected {{'children': []}}, got {result}"
8289

8390

8491
@pytest.mark.asyncio
8592
async def test_for_each_transform_flow_async() -> None:
8693
"""Test the complex transform flow asynchronously."""
8794
input_data = Parent(children=[Child(4), Child(5)])
8895
result = await for_each_transform.eval_async(input_data)
89-
expected = Parent(children=[Child(4), Child(5)])
96+
expected = {
97+
"children": [
98+
{"value": 4, "new_field": 4},
99+
{"value": 5, "new_field": 5},
100+
]
101+
}
102+
90103
assert result == expected, f"Expected {expected}, got {result}"

0 commit comments

Comments
 (0)