|
8 | 8 | import pytest |
9 | 9 | from numpy.typing import NDArray |
10 | 10 |
|
| 11 | +# Optional Pydantic support for testing |
| 12 | +try: |
| 13 | + from pydantic import BaseModel, Field |
| 14 | + |
| 15 | + PYDANTIC_AVAILABLE = True |
| 16 | +except ImportError: |
| 17 | + BaseModel = None # type: ignore[misc,assignment] |
| 18 | + Field = None # type: ignore[misc,assignment] |
| 19 | + PYDANTIC_AVAILABLE = False |
| 20 | + |
11 | 21 | import cocoindex |
12 | 22 | from cocoindex.convert import ( |
13 | 23 | dump_engine_object, |
@@ -70,6 +80,29 @@ class CustomerNamedTuple(NamedTuple): |
70 | 80 | tags: list[Tag] | None = None |
71 | 81 |
|
72 | 82 |
|
| 83 | +# Pydantic model definitions (if available) |
| 84 | +if PYDANTIC_AVAILABLE: |
| 85 | + |
| 86 | + class OrderPydantic(BaseModel): |
| 87 | + order_id: str |
| 88 | + name: str |
| 89 | + price: float |
| 90 | + extra_field: str = "default_extra" |
| 91 | + |
| 92 | + class TagPydantic(BaseModel): |
| 93 | + name: str |
| 94 | + |
| 95 | + class CustomerPydantic(BaseModel): |
| 96 | + name: str |
| 97 | + order: OrderPydantic |
| 98 | + tags: list[TagPydantic] | None = None |
| 99 | + |
| 100 | + class NestedStructPydantic(BaseModel): |
| 101 | + customer: CustomerPydantic |
| 102 | + orders: list[OrderPydantic] |
| 103 | + count: int = 0 |
| 104 | + |
| 105 | + |
73 | 106 | def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any: |
74 | 107 | """ |
75 | 108 | Encode a Python value to an engine value. |
@@ -1566,3 +1599,119 @@ class UnsupportedField: |
1566 | 1599 | match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ", |
1567 | 1600 | ): |
1568 | 1601 | build_engine_value_decoder(Base, UnsupportedField) |
| 1602 | + |
| 1603 | + |
| 1604 | +# Pydantic model tests |
| 1605 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1606 | +def test_pydantic_simple_struct() -> None: |
| 1607 | + """Test basic Pydantic model encoding and decoding.""" |
| 1608 | + order = OrderPydantic(order_id="O1", name="item1", price=10.0) |
| 1609 | + validate_full_roundtrip(order, OrderPydantic) |
| 1610 | + |
| 1611 | + |
| 1612 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1613 | +def test_pydantic_struct_with_defaults() -> None: |
| 1614 | + """Test Pydantic model with default values.""" |
| 1615 | + order = OrderPydantic(order_id="O1", name="item1", price=10.0) |
| 1616 | + assert order.extra_field == "default_extra" |
| 1617 | + validate_full_roundtrip(order, OrderPydantic) |
| 1618 | + |
| 1619 | + |
| 1620 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1621 | +def test_pydantic_nested_struct() -> None: |
| 1622 | + """Test nested Pydantic models.""" |
| 1623 | + order = OrderPydantic(order_id="O1", name="item1", price=10.0) |
| 1624 | + customer = CustomerPydantic(name="Alice", order=order) |
| 1625 | + validate_full_roundtrip(customer, CustomerPydantic) |
| 1626 | + |
| 1627 | + |
| 1628 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1629 | +def test_pydantic_struct_with_list() -> None: |
| 1630 | + """Test Pydantic model with list fields.""" |
| 1631 | + order = OrderPydantic(order_id="O1", name="item1", price=10.0) |
| 1632 | + tags = [TagPydantic(name="vip"), TagPydantic(name="premium")] |
| 1633 | + customer = CustomerPydantic(name="Alice", order=order, tags=tags) |
| 1634 | + validate_full_roundtrip(customer, CustomerPydantic) |
| 1635 | + |
| 1636 | + |
| 1637 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1638 | +def test_pydantic_complex_nested_struct() -> None: |
| 1639 | + """Test complex nested Pydantic structure.""" |
| 1640 | + order1 = OrderPydantic(order_id="O1", name="item1", price=10.0) |
| 1641 | + order2 = OrderPydantic(order_id="O2", name="item2", price=20.0) |
| 1642 | + customer = CustomerPydantic( |
| 1643 | + name="Alice", order=order1, tags=[TagPydantic(name="vip")] |
| 1644 | + ) |
| 1645 | + nested = NestedStructPydantic(customer=customer, orders=[order1, order2], count=2) |
| 1646 | + validate_full_roundtrip(nested, NestedStructPydantic) |
| 1647 | + |
| 1648 | + |
| 1649 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1650 | +def test_pydantic_struct_to_dict_binding() -> None: |
| 1651 | + """Test Pydantic model -> dict binding.""" |
| 1652 | + order = OrderPydantic(order_id="O1", name="item1", price=10.0, extra_field="custom") |
| 1653 | + expected_dict = { |
| 1654 | + "order_id": "O1", |
| 1655 | + "name": "item1", |
| 1656 | + "price": 10.0, |
| 1657 | + "extra_field": "custom", |
| 1658 | + } |
| 1659 | + |
| 1660 | + validate_full_roundtrip( |
| 1661 | + order, |
| 1662 | + OrderPydantic, |
| 1663 | + (expected_dict, Any), |
| 1664 | + (expected_dict, dict), |
| 1665 | + (expected_dict, dict[Any, Any]), |
| 1666 | + (expected_dict, dict[str, Any]), |
| 1667 | + ) |
| 1668 | + |
| 1669 | + |
| 1670 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1671 | +def test_make_engine_value_decoder_pydantic_struct() -> None: |
| 1672 | + """Test engine value decoder for Pydantic models.""" |
| 1673 | + engine_val = ["O1", "item1", 10.0, "default_extra"] |
| 1674 | + decoder = build_engine_value_decoder(OrderPydantic) |
| 1675 | + result = decoder(engine_val) |
| 1676 | + |
| 1677 | + assert isinstance(result, OrderPydantic) |
| 1678 | + assert result.order_id == "O1" |
| 1679 | + assert result.name == "item1" |
| 1680 | + assert result.price == 10.0 |
| 1681 | + assert result.extra_field == "default_extra" |
| 1682 | + |
| 1683 | + |
| 1684 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1685 | +def test_make_engine_value_decoder_pydantic_nested() -> None: |
| 1686 | + """Test engine value decoder for nested Pydantic models.""" |
| 1687 | + engine_val = [ |
| 1688 | + "Alice", |
| 1689 | + ["O1", "item1", 10.0, "default_extra"], |
| 1690 | + [["vip"]], |
| 1691 | + ] |
| 1692 | + decoder = build_engine_value_decoder(CustomerPydantic) |
| 1693 | + result = decoder(engine_val) |
| 1694 | + |
| 1695 | + assert isinstance(result, CustomerPydantic) |
| 1696 | + assert result.name == "Alice" |
| 1697 | + assert isinstance(result.order, OrderPydantic) |
| 1698 | + assert result.order.order_id == "O1" |
| 1699 | + assert result.tags is not None |
| 1700 | + assert len(result.tags) == 1 |
| 1701 | + assert isinstance(result.tags[0], TagPydantic) |
| 1702 | + assert result.tags[0].name == "vip" |
| 1703 | + |
| 1704 | + |
| 1705 | +@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available") |
| 1706 | +def test_pydantic_mixed_with_dataclass() -> None: |
| 1707 | + """Test mixing Pydantic models with dataclasses.""" |
| 1708 | + |
| 1709 | + # Create a dataclass that uses a Pydantic model |
| 1710 | + @dataclass |
| 1711 | + class MixedStruct: |
| 1712 | + name: str |
| 1713 | + pydantic_order: OrderPydantic |
| 1714 | + |
| 1715 | + order = OrderPydantic(order_id="O1", name="item1", price=10.0) |
| 1716 | + mixed = MixedStruct(name="test", pydantic_order=order) |
| 1717 | + validate_full_roundtrip(mixed, MixedStruct) |
0 commit comments