Skip to content

Commit 0fbfd3f

Browse files
authored
Add unit tests for Engine to Python format conversion (#328)
* Add unit tests for Engine to Python format conversion * Refactor tests: Add build_converter for engine-to-python * test: add mismatch tests * test: add positional tests * test: refactor and parameterize for clarity
1 parent b5ba1b1 commit 0fbfd3f

File tree

1 file changed

+169
-4
lines changed

1 file changed

+169
-4
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 169 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
import dataclasses
22
import uuid
33
import datetime
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, make_dataclass
55
import pytest
6+
from cocoindex.typing import encode_enriched_type
67
from cocoindex.convert import to_engine_value
8+
from cocoindex.convert import make_engine_value_converter
79

810
@dataclass
911
class Order:
1012
order_id: str
1113
name: str
1214
price: float
15+
extra_field: str = "default_extra"
16+
17+
@dataclass
18+
class Tag:
19+
name: str
1320

1421
@dataclass
1522
class Basket:
@@ -19,6 +26,21 @@ class Basket:
1926
class Customer:
2027
name: str
2128
order: Order
29+
tags: list[Tag] = None
30+
31+
@dataclass
32+
class NestedStruct:
33+
customer: Customer
34+
orders: list[Order]
35+
count: int = 0
36+
37+
def build_engine_value_converter(engine_type_in_py, python_type=None):
38+
"""
39+
Helper to build a converter for the given engine-side type (as represented in Python).
40+
If python_type is not specified, uses engine_type_in_py as the target.
41+
"""
42+
engine_type = encode_enriched_type(engine_type_in_py)["type"]
43+
return make_engine_value_converter([], engine_type, python_type or engine_type_in_py)
2244

2345
def test_to_engine_value_basic_types():
2446
assert to_engine_value(123) == 123
@@ -40,19 +62,19 @@ def test_to_engine_value_date_time_types():
4062

4163
def test_to_engine_value_struct():
4264
order = Order(order_id="O123", name="mixed nuts", price=25.0)
43-
assert to_engine_value(order) == ["O123", "mixed nuts", 25.0]
65+
assert to_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
4466

4567
def test_to_engine_value_list_of_structs():
4668
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
47-
assert to_engine_value(orders) == [["O1", "item1", 10.0], ["O2", "item2", 20.0]]
69+
assert to_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
4870

4971
def test_to_engine_value_struct_with_list():
5072
basket = Basket(items=["apple", "banana"])
5173
assert to_engine_value(basket) == [["apple", "banana"]]
5274

5375
def test_to_engine_value_nested_struct():
5476
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
55-
assert to_engine_value(customer) == ["Alice", ["O1", "item1", 10.0]]
77+
assert to_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]
5678

5779
def test_to_engine_value_empty_list():
5880
assert to_engine_value([]) == []
@@ -67,3 +89,146 @@ def test_to_engine_value_tuple():
6789

6890
def test_to_engine_value_none():
6991
assert to_engine_value(None) is None
92+
93+
def test_make_engine_value_converter_basic_types():
94+
for engine_type_in_py, value in [
95+
(int, 42),
96+
(float, 3.14),
97+
(str, "hello"),
98+
(bool, True),
99+
# (type(None), None), # Removed unsupported NoneType
100+
]:
101+
converter = build_engine_value_converter(engine_type_in_py)
102+
assert converter(value) == value
103+
104+
@pytest.mark.parametrize(
105+
"converter_type, engine_val, expected",
106+
[
107+
# All fields match
108+
(Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")),
109+
# Extra field in engine value (should ignore extra)
110+
(Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")),
111+
# Fewer fields in engine value (should fill with default)
112+
(Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")),
113+
# More fields in engine value (should ignore extra)
114+
(Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")),
115+
# Truly extra field (should ignore the fifth field)
116+
(Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")),
117+
# Missing optional field in engine value (tags=None)
118+
(Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
119+
# Extra field in engine value for Customer (should ignore)
120+
(Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
121+
]
122+
)
123+
def test_struct_conversion_cases(converter_type, engine_val, expected):
124+
converter = build_engine_value_converter(converter_type)
125+
assert converter(engine_val) == expected
126+
127+
def test_make_engine_value_converter_collections():
128+
# List of structs
129+
converter = build_engine_value_converter(list[Order])
130+
engine_val = [
131+
["O1", "item1", 10.0, "default_extra"],
132+
["O2", "item2", 20.0, "default_extra"]
133+
]
134+
assert converter(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]
135+
# Struct with list field
136+
converter = build_engine_value_converter(Customer)
137+
engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
138+
assert converter(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
139+
# Struct with struct field
140+
converter = build_engine_value_converter(NestedStruct)
141+
engine_val = [
142+
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
143+
[["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]],
144+
2
145+
]
146+
assert converter(engine_val) == NestedStruct(
147+
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
148+
[Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")],
149+
2
150+
)
151+
152+
def make_engine_order(fields):
153+
return make_dataclass('EngineOrder', fields)
154+
155+
def make_python_order(fields, defaults=None):
156+
if defaults is None:
157+
defaults = {}
158+
# Move all fields with defaults to the end (Python dataclass requirement)
159+
non_default_fields = [(n, t) for n, t in fields if n not in defaults]
160+
default_fields = [(n, t) for n, t in fields if n in defaults]
161+
ordered_fields = non_default_fields + default_fields
162+
# Prepare the namespace for defaults (only for fields at the end)
163+
namespace = {k: defaults[k] for k, _ in default_fields}
164+
return make_dataclass('PythonOrder', ordered_fields, namespace=namespace)
165+
166+
@pytest.mark.parametrize(
167+
"engine_fields, python_fields, python_defaults, engine_val, expected_python_val",
168+
[
169+
# Extra field in Python (middle)
170+
(
171+
[("id", str), ("name", str)],
172+
[("id", str), ("price", float), ("name", str)],
173+
{"price": 0.0},
174+
["O123", "mixed nuts"],
175+
("O123", 0.0, "mixed nuts"),
176+
),
177+
# Missing field in Python (middle)
178+
(
179+
[("id", str), ("price", float), ("name", str)],
180+
[("id", str), ("name", str)],
181+
{},
182+
["O123", 25.0, "mixed nuts"],
183+
("O123", "mixed nuts"),
184+
),
185+
# Extra field in Python (start)
186+
(
187+
[("name", str), ("price", float)],
188+
[("extra", str), ("name", str), ("price", float)],
189+
{"extra": "default"},
190+
["mixed nuts", 25.0],
191+
("default", "mixed nuts", 25.0),
192+
),
193+
# Missing field in Python (start)
194+
(
195+
[("extra", str), ("name", str), ("price", float)],
196+
[("name", str), ("price", float)],
197+
{},
198+
["unexpected", "mixed nuts", 25.0],
199+
("mixed nuts", 25.0),
200+
),
201+
# Field order difference (should map by name)
202+
(
203+
[("id", str), ("name", str), ("price", float)],
204+
[("name", str), ("id", str), ("price", float), ("extra", str)],
205+
{"extra": "default"},
206+
["O123", "mixed nuts", 25.0],
207+
("mixed nuts", "O123", 25.0, "default"),
208+
),
209+
# Extra field (Python has extra field with default)
210+
(
211+
[("id", str), ("name", str)],
212+
[("id", str), ("name", str), ("price", float)],
213+
{"price": 0.0},
214+
["O123", "mixed nuts"],
215+
("O123", "mixed nuts", 0.0),
216+
),
217+
# Missing field (Engine has extra field)
218+
(
219+
[("id", str), ("name", str), ("price", float)],
220+
[("id", str), ("name", str)],
221+
{},
222+
["O123", "mixed nuts", 25.0],
223+
("O123", "mixed nuts"),
224+
),
225+
]
226+
)
227+
def test_field_position_cases(engine_fields, python_fields, python_defaults, engine_val, expected_python_val):
228+
EngineOrder = make_engine_order(engine_fields)
229+
PythonOrder = make_python_order(python_fields, python_defaults)
230+
converter = build_engine_value_converter(EngineOrder, PythonOrder)
231+
# Map field names to expected values
232+
expected_dict = dict(zip([f[0] for f in python_fields], expected_python_val))
233+
# Instantiate using keyword arguments (order doesn't matter)
234+
assert converter(engine_val) == PythonOrder(**expected_dict)

0 commit comments

Comments
 (0)