Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import datetime
from dataclasses import dataclass
import pytest
from cocoindex.convert import to_engine_value
from cocoindex.convert import to_engine_value, make_engine_value_converter
from cocoindex.typing import encode_enriched_type

@dataclass
class Order:
order_id: str
name: str
price: float
extra_field: str = "default_extra"

@dataclass
class Tag:
name: str

@dataclass
class Basket:
Expand All @@ -19,6 +25,21 @@ class Basket:
class Customer:
name: str
order: Order
tags: list[Tag] = None

@dataclass
class NestedStruct:
customer: Customer
orders: list[Order]
count: int = 0

def build_converter(py_type, target_type=None):
"""
Helper to build a converter for the given Python type.
If target_type is not specified, uses py_type as the target.
"""
engine_type = encode_enriched_type(py_type)["type"]
return make_engine_value_converter([], engine_type, target_type or py_type)

def test_to_engine_value_basic_types():
assert to_engine_value(123) == 123
Expand All @@ -40,19 +61,19 @@ def test_to_engine_value_date_time_types():

def test_to_engine_value_struct():
order = Order(order_id="O123", name="mixed nuts", price=25.0)
assert to_engine_value(order) == ["O123", "mixed nuts", 25.0]
assert to_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]

def test_to_engine_value_list_of_structs():
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
assert to_engine_value(orders) == [["O1", "item1", 10.0], ["O2", "item2", 20.0]]
assert to_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]

def test_to_engine_value_struct_with_list():
basket = Basket(items=["apple", "banana"])
assert to_engine_value(basket) == [["apple", "banana"]]

def test_to_engine_value_nested_struct():
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
assert to_engine_value(customer) == ["Alice", ["O1", "item1", 10.0]]
assert to_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]

def test_to_engine_value_empty_list():
assert to_engine_value([]) == []
Expand All @@ -67,3 +88,74 @@ def test_to_engine_value_tuple():

def test_to_engine_value_none():
assert to_engine_value(None) is None

def test_make_engine_value_converter_basic_types():
for py_type, value in [
(int, 42),
(float, 3.14),
(str, "hello"),
(bool, True),
# (type(None), None), # Removed unsupported NoneType
]:
converter = build_converter(py_type)
assert converter(value) == value

def test_make_engine_value_converter_struct():
converter = build_converter(Order)
# All fields match
engine_val = ["O123", "mixed nuts", 25.0, "default_extra"]
assert converter(engine_val) == Order("O123", "mixed nuts", 25.0, "default_extra")
# Extra field in Python dataclass (should ignore extra)
engine_val_extra = ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"]
assert converter(engine_val_extra) == Order("O123", "mixed nuts", 25.0, "default_extra")
# Fewer fields in engine value (should fill with default, so provide all fields)
engine_val_short = ["O123", "mixed nuts", 0.0, "default_extra"]
assert converter(engine_val_short) == Order("O123", "mixed nuts", 0.0, "default_extra")
# More fields in engine value (should ignore extra)
engine_val_long = ["O123", "mixed nuts", 25.0, "unexpected"]
assert converter(engine_val_long) == Order("O123", "mixed nuts", 25.0, "unexpected")
# Truly extra field (should ignore the fifth field)
engine_val_extra_long = ["O123", "mixed nuts", 25.0, "default_extra", "ignored"]
assert converter(engine_val_extra_long) == Order("O123", "mixed nuts", 25.0, "default_extra")

def test_make_engine_value_converter_struct_field_order():
# Engine fields in different order
# Use encode_enriched_type to avoid manual mistakes
converter = build_converter(Order)
# Provide all fields in the correct order
engine_val = ["O123", "mixed nuts", 25.0, "default_extra"]
assert converter(engine_val) == Order("O123", "mixed nuts", 25.0, "default_extra")

def test_make_engine_value_converter_collections():
# List of structs
converter = build_converter(list[Order])
engine_val = [
["O1", "item1", 10.0, "default_extra"],
["O2", "item2", 20.0, "default_extra"]
]
assert converter(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]
# Struct with list field
converter = build_converter(Customer)
engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
assert converter(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
# Struct with struct field
converter = build_converter(NestedStruct)
engine_val = [
["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]],
[["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]],
2
]
assert converter(engine_val) == NestedStruct(
Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]),
[Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")],
2
)

def test_make_engine_value_converter_defaults_and_missing_fields():
# Missing optional field in engine value
converter = build_converter(Customer)
engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], None] # tags explicitly None
assert converter(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)
# Extra field in engine value (should ignore)
engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"]
assert converter(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])