Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions docs/docs/core/data_types.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ The Python type is `T1 | T2 | ...`, e.g. `cocoindex.Int64 | cocoindex.Float32 |

A *Struct* has a bunch of fields, each with a name and a type.

In Python, a *Struct* type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html)
or a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), with all fields annotated with a specific type.
Both options define a structured type with named fields, but they differ slightly:
In Python, a *Struct* type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html),
a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), or a [Pydantic model](https://pydantic.dev/),
with all fields annotated with a specific type.
These options define a structured type with named fields, but they differ slightly:

- **Dataclass**: A flexible class-based structure, mutable by default, defined using the `@dataclass` decorator.
- **NamedTuple**: An immutable tuple-based structure, defined using `typing.NamedTuple`.
- **Pydantic model**: A modern data validation and parsing structure, defined by inheriting from `pydantic.BaseModel`.

For example:

Expand All @@ -131,10 +133,22 @@ class PersonTuple(NamedTuple):
first_name: str
last_name: str
dob: datetime.date

# Using Pydantic (optional dependency)
try:
from pydantic import BaseModel

class PersonModel(BaseModel):
first_name: str
last_name: str
dob: datetime.date
except ImportError:
# Pydantic is optional
pass
```

Both `Person` and `PersonTuple` are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)).
Choose `dataclass` for mutable objects or when you need additional methods, and `NamedTuple` for immutable, lightweight structures.
All three examples (`Person`, `PersonTuple`, and `PersonModel`) are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)).
Choose `dataclass` for mutable objects, `NamedTuple` for immutable lightweight structures, or `Pydantic` for data validation and serialization features.

Besides, for arguments of custom functions, CocoIndex also supports using dictionaries (`dict[str, Any]`) to represent a *Struct* type.
It's the default Python type if you don't annotate the function argument with a specific type.
Expand Down Expand Up @@ -165,10 +179,10 @@ When a specific type annotation is not provided:
- The value binds to `dict[str, Any]`.


For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a *KTable*, with 4 columns: key (*Str*), `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
For example, you can use `dict[str, Person]`, `dict[str, PersonTuple]`, or `dict[str, PersonModel]` to represent a *KTable*, with 4 columns: key (*Str*), `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
It's bound to `dict[str, dict[str, Any]]` if you don't annotate the function argument with a specific type.

Note that when using a Struct as the key, it must be immutable in Python. For a dataclass, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For example:
Note that when using a Struct as the key, it must be immutable in Python. For a dataclass, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For Pydantic models, use `frozen=True` in the model configuration. For example:

```python
@dataclass(frozen=True)
Expand All @@ -179,9 +193,20 @@ class PersonKey:
class PersonKeyTuple(NamedTuple):
id_kind: str
id: str

# Pydantic frozen model (if available)
try:
from pydantic import BaseModel

class PersonKeyModel(BaseModel):
model_config = {"frozen": True}
id_kind: str
id: str
except ImportError:
pass
```

Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by both `id_kind` and `id`.
Then you can use `dict[PersonKey, Person]`, `dict[PersonKeyTuple, PersonTuple]`, or `dict[PersonKeyModel, PersonModel]` to represent a KTable keyed by both `id_kind` and `id`.
If you don't annotate the function argument with a specific type, it's bound to `dict[tuple[str, str], dict[str, Any]]`.


Expand All @@ -190,7 +215,7 @@ If you don't annotate the function argument with a specific type, it's bound to
*LTable* is a *Table* type whose row order is preserved. *LTable* has no key column.

In Python, a *LTable* type is represented by `list[R]`, where `R` is the type binding to the *Struct* type representing the value fields of each row.
For example, you can use `list[Person]` to represent a *LTable* with 3 columns: `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
For example, you can use `list[Person]`, `list[PersonTuple]`, or `list[PersonModel]` to represent a *LTable* with 3 columns: `first_name` (*Str*), `last_name` (*Str*), `dob` (*Date*).
It's bound to `list[dict[str, Any]]` if you don't annotate the function argument with a specific type.

## Key Types
Expand Down
83 changes: 78 additions & 5 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
analyze_type_info,
encode_enriched_type,
is_namedtuple_type,
is_pydantic_model,
is_numpy_number_type,
extract_ndarray_elem_dtype,
ValueType,
Expand Down Expand Up @@ -167,6 +168,29 @@ def encode_namedtuple(value: Any) -> Any:

return encode_namedtuple

elif is_pydantic_model(struct_type):
# Type guard: ensure we have model_fields attribute
if hasattr(struct_type, "model_fields"):
field_names = list(struct_type.model_fields.keys()) # type: ignore[attr-defined]
field_encoders = [
make_engine_value_encoder(
analyze_type_info(struct_type.model_fields[name].annotation) # type: ignore[attr-defined]
)
for name in field_names
]
else:
raise ValueError(f"Invalid Pydantic model: {struct_type}")

def encode_pydantic(value: Any) -> Any:
if value is None:
return None
return [
encoder(getattr(value, name))
for encoder, name in zip(field_encoders, field_names)
]

return encode_pydantic

def encode_basic_value(value: Any) -> Any:
if isinstance(value, np.number):
return value.item()
Expand Down Expand Up @@ -472,7 +496,7 @@ def make_engine_struct_decoder(
if not isinstance(dst_type_variant, AnalyzedStructType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple or dict[str, Any] expected"
f"declared `{dst_type_info.core_type}`, a dataclass, NamedTuple, Pydantic model or dict[str, Any] expected"
)

src_name_to_idx = {f.name: i for i, f in enumerate(src_fields)}
Expand All @@ -495,6 +519,26 @@ def make_engine_struct_decoder(
)
for name in fields
}
elif is_pydantic_model(dst_struct_type):
# For Pydantic models, we can use model_fields to get field information
parameters = {}
# Type guard: ensure we have model_fields attribute
if hasattr(dst_struct_type, "model_fields"):
model_fields = dst_struct_type.model_fields # type: ignore[attr-defined]
else:
model_fields = {}
for name, field_info in model_fields.items():
default_value = (
field_info.default
if field_info.default is not ...
else inspect.Parameter.empty
)
parameters[name] = inspect.Parameter(
name=name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=default_value,
annotation=field_info.annotation,
)
else:
raise ValueError(f"Unsupported struct type: {dst_struct_type}")

Expand Down Expand Up @@ -536,9 +580,21 @@ def make_closure_for_field(
make_closure_for_field(name, param) for (name, param) in parameters.items()
]

return lambda values: dst_struct_type(
*(decoder(values) for decoder in field_value_decoder)
)
# Different construction for different struct types
if is_pydantic_model(dst_struct_type):
# Pydantic models prefer keyword arguments
field_names = list(parameters.keys())
return lambda values: dst_struct_type(
**{
field_names[i]: decoder(values)
for i, decoder in enumerate(field_value_decoder)
}
)
else:
# Dataclasses and NamedTuples can use positional arguments
return lambda values: dst_struct_type(
*(decoder(values) for decoder in field_value_decoder)
)


def _make_engine_struct_to_dict_decoder(
Expand Down Expand Up @@ -718,7 +774,7 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
for k, val in v.items()
}

# Structs (dataclass or NamedTuple)
# Structs (dataclass, NamedTuple, or Pydantic)
if isinstance(variant, AnalyzedStructType):
struct_type = variant.struct_type
if dataclasses.is_dataclass(struct_type):
Expand All @@ -743,6 +799,23 @@ def load_engine_object(expected_type: Any, v: Any) -> Any:
if name in v:
nt_init_kwargs[name] = load_engine_object(f_type, v[name])
return struct_type(**nt_init_kwargs)
elif is_pydantic_model(struct_type):
if not isinstance(v, Mapping):
raise ValueError(f"Expected dict for Pydantic model, got {type(v)}")
# Drop auxiliary discriminator "kind" if present
pydantic_init_kwargs: dict[str, Any] = {}
# Type guard: ensure we have model_fields attribute
if hasattr(struct_type, "model_fields"):
model_fields = struct_type.model_fields # type: ignore[attr-defined]
else:
model_fields = {}
field_types = {
name: field.annotation for name, field in model_fields.items()
}
for name, f_type in field_types.items():
if name in v:
pydantic_init_kwargs[name] = load_engine_object(f_type, v[name])
return struct_type(**pydantic_init_kwargs)
return v

# Union with discriminator support via "kind"
Expand Down
149 changes: 149 additions & 0 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
import pytest
from numpy.typing import NDArray

# Optional Pydantic support for testing
try:
from pydantic import BaseModel, Field

PYDANTIC_AVAILABLE = True
except ImportError:
BaseModel = None # type: ignore[misc,assignment]
Field = None # type: ignore[misc,assignment]
PYDANTIC_AVAILABLE = False

import cocoindex
from cocoindex.convert import (
dump_engine_object,
Expand Down Expand Up @@ -70,6 +80,29 @@ class CustomerNamedTuple(NamedTuple):
tags: list[Tag] | None = None


# Pydantic model definitions (if available)
if PYDANTIC_AVAILABLE:

class OrderPydantic(BaseModel):
order_id: str
name: str
price: float
extra_field: str = "default_extra"

class TagPydantic(BaseModel):
name: str

class CustomerPydantic(BaseModel):
name: str
order: OrderPydantic
tags: list[TagPydantic] | None = None

class NestedStructPydantic(BaseModel):
customer: CustomerPydantic
orders: list[OrderPydantic]
count: int = 0


def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
"""
Encode a Python value to an engine value.
Expand Down Expand Up @@ -1566,3 +1599,119 @@ class UnsupportedField:
match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
):
build_engine_value_decoder(Base, UnsupportedField)


# Pydantic model tests
@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_simple_struct() -> None:
"""Test basic Pydantic model encoding and decoding."""
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
validate_full_roundtrip(order, OrderPydantic)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_struct_with_defaults() -> None:
"""Test Pydantic model with default values."""
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
assert order.extra_field == "default_extra"
validate_full_roundtrip(order, OrderPydantic)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_nested_struct() -> None:
"""Test nested Pydantic models."""
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
customer = CustomerPydantic(name="Alice", order=order)
validate_full_roundtrip(customer, CustomerPydantic)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_struct_with_list() -> None:
"""Test Pydantic model with list fields."""
order = OrderPydantic(order_id="O1", name="item1", price=10.0)
tags = [TagPydantic(name="vip"), TagPydantic(name="premium")]
customer = CustomerPydantic(name="Alice", order=order, tags=tags)
validate_full_roundtrip(customer, CustomerPydantic)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_complex_nested_struct() -> None:
"""Test complex nested Pydantic structure."""
order1 = OrderPydantic(order_id="O1", name="item1", price=10.0)
order2 = OrderPydantic(order_id="O2", name="item2", price=20.0)
customer = CustomerPydantic(
name="Alice", order=order1, tags=[TagPydantic(name="vip")]
)
nested = NestedStructPydantic(customer=customer, orders=[order1, order2], count=2)
validate_full_roundtrip(nested, NestedStructPydantic)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_struct_to_dict_binding() -> None:
"""Test Pydantic model -> dict binding."""
order = OrderPydantic(order_id="O1", name="item1", price=10.0, extra_field="custom")
expected_dict = {
"order_id": "O1",
"name": "item1",
"price": 10.0,
"extra_field": "custom",
}

validate_full_roundtrip(
order,
OrderPydantic,
(expected_dict, Any),
(expected_dict, dict),
(expected_dict, dict[Any, Any]),
(expected_dict, dict[str, Any]),
)


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_make_engine_value_decoder_pydantic_struct() -> None:
"""Test engine value decoder for Pydantic models."""
engine_val = ["O1", "item1", 10.0, "default_extra"]
decoder = build_engine_value_decoder(OrderPydantic)
result = decoder(engine_val)

assert isinstance(result, OrderPydantic)
assert result.order_id == "O1"
assert result.name == "item1"
assert result.price == 10.0
assert result.extra_field == "default_extra"


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_make_engine_value_decoder_pydantic_nested() -> None:
"""Test engine value decoder for nested Pydantic models."""
engine_val = [
"Alice",
["O1", "item1", 10.0, "default_extra"],
[["vip"]],
]
decoder = build_engine_value_decoder(CustomerPydantic)
result = decoder(engine_val)

assert isinstance(result, CustomerPydantic)
assert result.name == "Alice"
assert isinstance(result.order, OrderPydantic)
assert result.order.order_id == "O1"
assert result.tags is not None
assert len(result.tags) == 1
assert isinstance(result.tags[0], TagPydantic)
assert result.tags[0].name == "vip"


@pytest.mark.skipif(not PYDANTIC_AVAILABLE, reason="Pydantic not available")
def test_pydantic_mixed_with_dataclass() -> None:
"""Test mixing Pydantic models with dataclasses."""

# Create a dataclass that uses a Pydantic model
@dataclass
class MixedStruct:
name: str
pydantic_order: OrderPydantic

order = OrderPydantic(order_id="O1", name="item1", price=10.0)
mixed = MixedStruct(name="test", pydantic_order=order)
validate_full_roundtrip(mixed, MixedStruct)
Loading