Commit 382b1bf

support binding classes on Pydantic Models to CocoIndex Struct
1 parent 31e66d9 commit 382b1bf

File tree

4 files changed

+296
-15
lines changed

docs/docs/core/data_types.mdx

Lines changed: 34 additions & 9 deletions
@@ -105,12 +105,14 @@ The Python type is `T1 | T2 | ...`, e.g. `cocoindex.Int64 | cocoindex.Float32 |
105105

106106
A *Struct* has a bunch of fields, each with a name and a type.
107107

108-
In Python, a *Struct* type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html)
109-
or a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), with all fields annotated with a specific type.
110-
Both options define a structured type with named fields, but they differ slightly:
108+
In Python, a *Struct* type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html),
109+
a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), or a [Pydantic model](https://pydantic.dev/),
110+
with all fields annotated with a specific type.
111+
These options define a structured type with named fields, but they differ slightly:
111112

112113
- **Dataclass**: A flexible class-based structure, mutable by default, defined using the `@dataclass` decorator.
113114
- **NamedTuple**: An immutable tuple-based structure, defined using `typing.NamedTuple`.
115+
- **Pydantic model**: A class-based structure with runtime data validation and parsing, defined by subclassing `pydantic.BaseModel`.
114116

115117
For example:
116118

@@ -131,10 +133,22 @@ class PersonTuple(NamedTuple):
131133
first_name: str
132134
last_name: str
133135
dob: datetime.date
136+
137+
# Using Pydantic (optional dependency)
138+
try:
139+
from pydantic import BaseModel
140+
141+
class PersonModel(BaseModel):
142+
first_name: str
143+
last_name: str
144+
dob: datetime.date
145+
except ImportError:
146+
# Pydantic is optional
147+
pass
134148
```
135149

136-
Both `Person` and `PersonTuple` are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)).
137-
Choose `dataclass` for mutable objects or when you need additional methods, and `NamedTuple` for immutable, lightweight structures.
150+
All three examples (`Person`, `PersonTuple`, and `PersonModel`) are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)).
151+
Choose a `dataclass` for mutable objects, a `NamedTuple` for immutable lightweight structures, or a Pydantic model when you need data validation and serialization.
138152

139153
Besides, for arguments of custom functions, CocoIndex also supports using dictionaries (`dict[str, Any]`) to represent a *Struct* type.
140154
It's the default Python type if you don't annotate the function argument with a specific type.
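
As a rough illustration of the "identical schemas" statement above, the following sketch reads field names and annotations from each kind of class using ordinary Python introspection. `field_schema` is a hypothetical helper written for this note, not a CocoIndex API. The Pydantic branch assumes Pydantic v2 (`model_fields`) and is skipped when the package is not installed.

```python
import dataclasses
import datetime
from typing import Any, NamedTuple

try:
    from pydantic import BaseModel  # optional dependency

    HAS_PYDANTIC_V2 = hasattr(BaseModel, "model_fields")
except ImportError:
    HAS_PYDANTIC_V2 = False


@dataclasses.dataclass
class Person:
    first_name: str
    last_name: str
    dob: datetime.date


class PersonTuple(NamedTuple):
    first_name: str
    last_name: str
    dob: datetime.date


def field_schema(struct_type: Any) -> list[tuple[str, Any]]:
    """Hypothetical helper: (field name, annotation) pairs in declaration order."""
    if dataclasses.is_dataclass(struct_type):
        return [(f.name, f.type) for f in dataclasses.fields(struct_type)]
    if HAS_PYDANTIC_V2 and issubclass(struct_type, BaseModel):
        # Assumption: Pydantic v2, which exposes declared fields via `model_fields`.
        return [(n, f.annotation) for n, f in struct_type.model_fields.items()]
    if issubclass(struct_type, tuple) and hasattr(struct_type, "_fields"):
        hints = struct_type.__annotations__
        return [(n, hints[n]) for n in struct_type._fields]
    raise TypeError(f"Unsupported struct type: {struct_type}")


schemas = {
    "Person": field_schema(Person),
    "PersonTuple": field_schema(PersonTuple),
}

if HAS_PYDANTIC_V2:

    class PersonModel(BaseModel):
        first_name: str
        last_name: str
        dob: datetime.date

    schemas["PersonModel"] = field_schema(PersonModel)
```

Every entry in `schemas` lists the same three fields in the same order, which is the sense in which the definitions are interchangeable as Struct types.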
@@ -165,10 +179,10 @@ When a specific type annotation is not provided:
165179
- The value binds to `dict[str, Any]`.
166180

167181

168-
For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a *KTable*, with 4 columns: key (*Str*), `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
182+
For example, you can use `dict[str, Person]`, `dict[str, PersonTuple]`, or `dict[str, PersonModel]` to represent a *KTable*, with 4 columns: key (*Str*), `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
169183
It's bound to `dict[str, dict[str, Any]]` if you don't annotate the function argument with a specific type.
170184

171-
Note that when using a Struct as the key, it must be immutable in Python. For a dataclass, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For example:
185+
Note that when using a Struct as the key, it must be immutable in Python. For a dataclass, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For Pydantic models, use `frozen=True` in the model configuration. For example:
172186

173187
```python
174188
@dataclass(frozen=True)
@@ -179,9 +193,20 @@ class PersonKey:
179193
class PersonKeyTuple(NamedTuple):
180194
id_kind: str
181195
id: str
196+
197+
# Pydantic frozen model (if available)
198+
try:
199+
from pydantic import BaseModel
200+
201+
class PersonKeyModel(BaseModel):
202+
model_config = {"frozen": True}
203+
id_kind: str
204+
id: str
205+
except ImportError:
206+
pass
182207
```
183208

184-
Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by both `id_kind` and `id`.
209+
Then you can use `dict[PersonKey, Person]`, `dict[PersonKeyTuple, PersonTuple]`, or `dict[PersonKeyModel, PersonModel]` to represent a KTable keyed by both `id_kind` and `id`.
185210
If you don't annotate the function argument with a specific type, it's bound to `dict[tuple[str, str], dict[str, Any]]`.
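
The immutability requirement for keys mirrors Python's own rule that dictionary keys must be hashable. The sketch below shows the Python side only: a frozen dataclass and a `NamedTuple` can be hashed, while a default (mutable) dataclass cannot. `MutablePersonKey` and `is_hashable` are hypothetical names used for illustration, and how CocoIndex itself validates keys is not shown here.

```python
from dataclasses import dataclass
from typing import NamedTuple


@dataclass(frozen=True)
class PersonKey:
    id_kind: str
    id: str


@dataclass  # eq=True without frozen=True sets __hash__ to None
class MutablePersonKey:
    id_kind: str
    id: str


class PersonKeyTuple(NamedTuple):
    id_kind: str
    id: str


def is_hashable(value: object) -> bool:
    """Return True if `value` can be used as a dict key."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


# A frozen dataclass works as a dict key; an equal instance finds the same entry.
table = {PersonKey("passport", "X123"): "Alice"}
found = table[PersonKey("passport", "X123")]

frozen_ok = is_hashable(PersonKey("passport", "X123"))
tuple_ok = is_hashable(PersonKeyTuple("passport", "X123"))
mutable_ok = is_hashable(MutablePersonKey("passport", "X123"))
```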
186211

187212

@@ -190,7 +215,7 @@ If you don't annotate the function argument with a specific type, it's bound to
190215
*LTable* is a *Table* type whose row order is preserved. *LTable* has no key column.
191216

192217
In Python, a *LTable* type is represented by `list[R]`, where `R` is the type binding to the *Struct* type representing the value fields of each row.
193-
For example, you can use `list[Person]` to represent a *LTable* with 3 columns: `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
218+
For example, you can use `list[Person]`, `list[PersonTuple]`, or `list[PersonModel]` to represent a *LTable* with 3 columns: `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
194219
It's bound to `list[dict[str, Any]]` if you don't annotate the function argument with a specific type.
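
To make the typed and untyped bindings concrete, the sketch below builds a KTable-shaped `dict[str, Person]` and an LTable-shaped `list[Person]`, then derives the `dict[str, Any]` row form from each. It is a Python-side illustration only: `dataclasses.asdict` stands in for the conversion, which is an assumption and not how CocoIndex performs the binding, and the column label `"key"` is used purely for display.

```python
import dataclasses
import datetime
from typing import Any


@dataclasses.dataclass
class Person:
    first_name: str
    last_name: str
    dob: datetime.date


alice = Person("Alice", "Smith", datetime.date(1990, 1, 1))
bob = Person("Bob", "Jones", datetime.date(1985, 6, 15))

# KTable shape: the dict key is the key column, the Struct holds the value columns.
people_by_id: dict[str, Person] = {"p1": alice, "p2": bob}

# LTable shape: an ordered list of rows with no key column.
people_in_order: list[Person] = [alice, bob]

# The row form used when no specific type annotation is given.
untyped_ktable: dict[str, dict[str, Any]] = {
    key: dataclasses.asdict(person) for key, person in people_by_id.items()
}
untyped_ltable: list[dict[str, Any]] = [
    dataclasses.asdict(person) for person in people_in_order
]

ktable_columns = ["key", *untyped_ktable["p1"].keys()]
ltable_columns = list(untyped_ltable[0].keys())
```

The KTable view has four columns (the key plus three value fields), while the LTable view has only the three value fields and relies on list order.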
195220

196221
## Key Types

python/cocoindex/convert.py

Lines changed: 78 additions & 5 deletions
@@ -26,6 +26,7 @@
2626
analyze_type_info,
2727
encode_enriched_type,
2828
is_namedtuple_type,
29+
is_pydantic_model,
2930
is_numpy_number_type,
3031
extract_ndarray_elem_dtype,
3132
ValueType,
@@ -167,6 +168,29 @@ def encode_namedtuple(value: Any) -> Any:
167168

168169
return encode_namedtuple
169170

171+
elif is_pydantic_model(struct_type):
172+
# Type guard: ensure we have model_fields attribute
173+
if hasattr(struct_type, "model_fields"):
174+
field_names = list(struct_type.model_fields.keys()) # type: ignore[attr-defined]
175+
field_encoders = [
176+
make_engine_value_encoder(
177+
analyze_type_info(struct_type.model_fields[name].annotation) # type: ignore[attr-defined]
178+
)
179+
for name in field_names
180+
]
181+
else:
182+
raise ValueError(f"Invalid Pydantic model: {struct_type}")
183+
184+
def encode_pydantic(value: Any) -> Any:
185+
if value is None:
186+
return None
187+
return [
188+
encoder(getattr(value, name))
189+
for encoder, name in zip(field_encoders, field_names)
190+
]
191+
192+
return encode_pydantic
193+
170194
def encode_basic_value(value: Any) -> Any:
171195
if isinstance(value, np.number):
172196
return value.item()
@@ -472,7 +496,7 @@ def make_engine_struct_decoder(
472496
if not isinstance(dst_type_variant, AnalyzedStructType):
473497
raise ValueError(
474498
f"Type mismatch for `{''.join(field_path)}`: "
475-
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
499+
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple, Pydantic model or dict[str, Any] expected"
476500
)
477501

478502
src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
@@ -495,6 +519,26 @@ def make_engine_struct_decoder(
495519
)
496520
for name in fields
497521
}
522+
elif is_pydantic_model(dst_struct_type):
523+
# For Pydantic models, we can use model_fields to get field information
524+
parameters = {}
525+
# Type guard: ensure we have model_fields attribute
526+
if hasattr(dst_struct_type, "model_fields"):
527+
model_fields = dst_struct_type.model_fields # type: ignore[attr-defined]
528+
else:
529+
model_fields = {}
530+
for name, field_info in model_fields.items():
531+
default_value = (
532+
field_info.default
533+
if field_info.default is not ...
534+
else inspect.Parameter.empty
535+
)
536+
parameters[name] = inspect.Parameter(
537+
name=name,
538+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
539+
default=default_value,
540+
annotation=field_info.annotation,
541+
)
498542
else:
499543
raise ValueError(f"Unsupported struct type: {dst_struct_type}")
500544

@@ -536,9 +580,21 @@ def make_closure_for_field(
536580
make_closure_for_field(name, param) for (name, param) in parameters.items()
537581
]
538582

539-
return lambda values: dst_struct_type(
540-
*(decoder(values) for decoder in field_value_decoder)
541-
)
583+
# Different construction for different struct types
584+
if is_pydantic_model(dst_struct_type):
585+
# Pydantic models prefer keyword arguments
586+
field_names = list(parameters.keys())
587+
return lambda values: dst_struct_type(
588+
**{
589+
field_names[i]: decoder(values)
590+
for i, decoder in enumerate(field_value_decoder)
591+
}
592+
)
593+
else:
594+
# Dataclasses and NamedTuples can use positional arguments
595+
return lambda values: dst_struct_type(
596+
*(decoder(values) for decoder in field_value_decoder)
597+
)
542598

543599

544600
def _make_engine_struct_to_dict_decoder(
@@ -718,7 +774,7 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
718774
for k, val in v.items()
719775
}
720776

721-
# Structs (dataclass or NamedTuple)
777+
# Structs (dataclass, NamedTuple, or Pydantic)
722778
if isinstance(variant, AnalyzedStructType):
723779
struct_type = variant.struct_type
724780
if dataclasses.is_dataclass(struct_type):
@@ -743,6 +799,23 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
743799
if name in v:
744800
nt_init_kwargs[name] = load_engine_object(f_type, v[name])
745801
return struct_type(**nt_init_kwargs)
802+
elif is_pydantic_model(struct_type):
803+
if not isinstance(v, Mapping):
804+
raise ValueError(f"Expected dict for Pydantic model, got {type(v)}")
805+
# Copy only the declared fields, which also drops the auxiliary "kind" discriminator if present
806+
pydantic_init_kwargs: dict[str, Any] = {}
807+
# Type guard: ensure we have model_fields attribute
808+
if hasattr(struct_type, "model_fields"):
809+
model_fields = struct_type.model_fields # type: ignore[attr-defined]
810+
else:
811+
model_fields = {}
812+
field_types = {
813+
name: field.annotation for name, field in model_fields.items()
814+
}
815+
for name, f_type in field_types.items():
816+
if name in v:
817+
pydantic_init_kwargs[name] = load_engine_object(f_type, v[name])
818+
return struct_type(**pydantic_init_kwargs)
746819
return v
747820

748821
# Union with discriminator support via "kind"
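
The encoder and decoder changes in this file follow one pattern: a struct is flattened to a list of field values in declaration order, and rebuilt from such a list. The sketch below reproduces that pattern for flat (non-nested) structs. `struct_field_names`, `encode_struct`, and `decode_struct` are hypothetical helpers, not the functions in `convert.py`, and the Pydantic branch assumes Pydantic v2.

```python
import dataclasses
from typing import Any, NamedTuple

try:
    from pydantic import BaseModel  # optional dependency

    HAS_PYDANTIC_V2 = hasattr(BaseModel, "model_fields")
except ImportError:
    HAS_PYDANTIC_V2 = False


def struct_field_names(struct_type: Any) -> list[str]:
    """Hypothetical helper: field names in declaration order."""
    if dataclasses.is_dataclass(struct_type):
        return [f.name for f in dataclasses.fields(struct_type)]
    if HAS_PYDANTIC_V2 and issubclass(struct_type, BaseModel):
        return list(struct_type.model_fields)
    if issubclass(struct_type, tuple) and hasattr(struct_type, "_fields"):
        return list(struct_type._fields)
    raise TypeError(f"Unsupported struct type: {struct_type}")


def encode_struct(value: Any) -> list[Any]:
    """Flatten a flat struct instance to a positional list of field values."""
    return [getattr(value, name) for name in struct_field_names(type(value))]


def decode_struct(struct_type: Any, values: list[Any]) -> Any:
    """Rebuild a struct by keyword, which all three struct kinds accept."""
    return struct_type(**dict(zip(struct_field_names(struct_type), values)))


@dataclasses.dataclass
class Order:
    order_id: str
    name: str
    price: float


class OrderTuple(NamedTuple):
    order_id: str
    name: str
    price: float


order = Order("O1", "item1", 10.0)
engine_value = encode_struct(order)
restored = decode_struct(Order, engine_value)
restored_tuple = decode_struct(OrderTuple, engine_value)

if HAS_PYDANTIC_V2:

    class OrderModel(BaseModel):
        order_id: str
        name: str
        price: float

    original_model = OrderModel(order_id="O1", name="item1", price=10.0)
    restored_model = decode_struct(OrderModel, encode_struct(original_model))
```

Keyword construction is used for every struct kind in this sketch because Pydantic's `BaseModel` constructor accepts keyword arguments only, while dataclasses and NamedTuples accept both forms.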

python/cocoindex/tests/test_convert.py

Lines changed: 149 additions & 0 deletions
@@ -8,6 +8,16 @@
88
import pytest
99
from numpy.typing import NDArray
1010

11+
# Optional Pydantic support for testing
12+
try:
13+
from pydantic import BaseModel, Field
14+
15+
PYDANTIC_AVAILABLE = True
16+
except ImportError:
17+
BaseModel = None # type: ignore[misc,assignment]
18+
Field = None # type: ignore[misc,assignment]
19+
PYDANTIC_AVAILABLE = False
20+
1121
import cocoindex
1222
from cocoindex.convert import (
1323
dump_engine_object,
@@ -70,6 +80,29 @@ class CustomerNamedTuple(NamedTuple):
7080
tags: list[Tag] | None = None
7181

7282

83+
# Pydantic model definitions (if available)
84+
if PYDANTIC_AVAILABLE:
85+
86+
class OrderPydantic(BaseModel):
87+
order_id: str
88+
name: str
89+
price: float
90+
extra_field: str = "default_extra"
91+
92+
class TagPydantic(BaseModel):
93+
name: str
94+
95+
class CustomerPydantic(BaseModel):
96+
name: str
97+
order: OrderPydantic
98+
tags: list[TagPydantic] | None = None
99+
100+
class NestedStructPydantic(BaseModel):
101+
customer: CustomerPydantic
102+
orders: list[OrderPydantic]
103+
count: int = 0
104+
105+
73106
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
74107
"""
75108
Encode a Python value to an engine value.
@@ -1566,3 +1599,119 @@ class UnsupportedField:
15661599
match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
15671600
):
15681601
build_engine_value_decoder(Base, UnsupportedField)
1602+
1603+
1604+
# Pydantic model tests
1605+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1606+
def test_pydantic_simple_struct() -> None:
1607+
"""Test basic Pydantic model encoding and decoding."""
1608+
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
1609+
validate_full_roundtrip(order, OrderPydantic)
1610+
1611+
1612+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1613+
def test_pydantic_struct_with_defaults() -> None:
1614+
"""Test Pydantic model with default values."""
1615+
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
1616+
assert order.extra_field == "default_extra"
1617+
validate_full_roundtrip(order, OrderPydantic)
1618+
1619+
1620+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1621+
def test_pydantic_nested_struct() -> None:
1622+
"""Test nested Pydantic models."""
1623+
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
1624+
customer = CustomerPydantic(name="Alice", order=order)
1625+
validate_full_roundtrip(customer, CustomerPydantic)
1626+
1627+
1628+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1629+
def test_pydantic_struct_with_list() -> None:
1630+
"""Test Pydantic model with list fields."""
1631+
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
1632+
tags = [TagPydantic(name="vip"), TagPydantic(name="premium")]
1633+
customer = CustomerPydantic(name="Alice", order=order, tags=tags)
1634+
validate_full_roundtrip(customer, CustomerPydantic)
1635+
1636+
1637+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1638+
def test_pydantic_complex_nested_struct() -> None:
1639+
"""Test complex nested Pydantic structure."""
1640+
order1 = OrderPydantic(order_id="O1", name="item1", price=10.0)
1641+
order2 = OrderPydantic(order_id="O2", name="item2", price=20.0)
1642+
customer = CustomerPydantic(
1643+
name="Alice", order=order1, tags=[TagPydantic(name="vip")]
1644+
)
1645+
nested = NestedStructPydantic(customer=customer, orders=[order1, order2], count=2)
1646+
validate_full_roundtrip(nested, NestedStructPydantic)
1647+
1648+
1649+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1650+
def test_pydantic_struct_to_dict_binding() -> None:
1651+
"""Test Pydantic model -> dict binding."""
1652+
order = OrderPydantic(order_id="O1", name="item1", price=10.0, extra_field="custom")
1653+
expected_dict = {
1654+
"order_id": "O1",
1655+
"name": "item1",
1656+
"price": 10.0,
1657+
"extra_field": "custom",
1658+
}
1659+
1660+
validate_full_roundtrip(
1661+
order,
1662+
OrderPydantic,
1663+
(expected_dict, Any),
1664+
(expected_dict, dict),
1665+
(expected_dict, dict[Any, Any]),
1666+
(expected_dict, dict[str, Any]),
1667+
)
1668+
1669+
1670+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1671+
def test_make_engine_value_decoder_pydantic_struct() -> None:
1672+
"""Test engine value decoder for Pydantic models."""
1673+
engine_val = ["O1", "item1", 10.0, "default_extra"]
1674+
decoder = build_engine_value_decoder(OrderPydantic)
1675+
result = decoder(engine_val)
1676+
1677+
assert isinstance(result, OrderPydantic)
1678+
assert result.order_id == "O1"
1679+
assert result.name == "item1"
1680+
assert result.price == 10.0
1681+
assert result.extra_field == "default_extra"
1682+
1683+
1684+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1685+
def test_make_engine_value_decoder_pydantic_nested() -> None:
1686+
"""Test engine value decoder for nested Pydantic models."""
1687+
engine_val = [
1688+
"Alice",
1689+
["O1", "item1", 10.0, "default_extra"],
1690+
[["vip"]],
1691+
]
1692+
decoder = build_engine_value_decoder(CustomerPydantic)
1693+
result = decoder(engine_val)
1694+
1695+
assert isinstance(result, CustomerPydantic)
1696+
assert result.name == "Alice"
1697+
assert isinstance(result.order, OrderPydantic)
1698+
assert result.order.order_id == "O1"
1699+
assert result.tags is not None
1700+
assert len(result.tags) == 1
1701+
assert isinstance(result.tags[0], TagPydantic)
1702+
assert result.tags[0].name == "vip"
1703+
1704+
1705+
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
1706+
def test_pydantic_mixed_with_dataclass() -> None:
1707+
"""Test mixing Pydantic models with dataclasses."""
1708+
1709+
# Create a dataclass that uses a Pydantic model
1710+
@dataclass
1711+
class MixedStruct:
1712+
name: str
1713+
pydantic_order: OrderPydantic
1714+
1715+
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
1716+
mixed = MixedStruct(name="test", pydantic_order=order)
1717+
validate_full_roundtrip(mixed, MixedStruct)
