diff --git a/docs/docs/core/data_types.mdx b/docs/docs/core/data_types.mdx index ffbeb6f10..21d36c5a9 100644 --- a/docs/docs/core/data_types.mdx +++ b/docs/docs/core/data_types.mdx @@ -53,23 +53,41 @@ The native Python type is always more permissive and can represent a superset of you can choose whatever to use. The native Python type is usually simpler. -### Struct Type +### Struct Types A Struct has a bunch of fields, each with a name and a type. -In Python, a Struct type is represented by a [dataclass](https://docs.python.org/3/library/dataclasses.html), -and all fields must be annotated with a specific type. For example: +In Python, a Struct type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html) +or a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), with all fields annotated with a specific type. +Both options define a structured type with named fields, but they differ slightly: + +- **Dataclass**: A flexible class-based structure, mutable by default, defined using the `@dataclass` decorator. +- **NamedTuple**: An immutable tuple-based structure, defined using `typing.NamedTuple`. + +For example: ```python from dataclasses import dataclass +from typing import NamedTuple +import datetime +# Using dataclass @dataclass class Person: first_name: str - last_name + last_name: str + dob: datetime.date + +# Using NamedTuple +class PersonTuple(NamedTuple): + first_name: str + last_name: str dob: datetime.date ``` +Both `Person` and `PersonTuple` are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)). +Choose `dataclass` for mutable objects or when you need additional methods, and `NamedTuple` for immutable, lightweight structures. + ### Table Types A Table type models a collection of rows, each with multiple columns. @@ -84,10 +102,10 @@ The row order of a KTable is not preserved. Type of the first column (key column) must be a [key type](#key-types). In Python, a KTable type is represented by `dict[K, V]`. -The `V` should be a dataclass, representing the value fields of each row. -For example, you can use `dict[str, Person]` to represent a KTable, with 4 columns: key (Str), `first_name` (Str), `last_name` (Str), `dob` (Date). +The `V` should be a struct type, either a `dataclass` or `NamedTuple`, representing the value fields of each row. +For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a KTable, with 4 columns: key (Str), `first_name` (Str), `last_name` (Str), `dob` (Date). -Note that if you want to use a struct as the key, you need to annotate the struct with `@dataclass(frozen=True)`, so the values are immutable. +Note that if you want to use a struct as the key, you need to ensure the struct is immutable. For `dataclass`, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in. For example: ```python @@ -95,9 +113,13 @@ For example: class PersonKey: id_kind: str id: str + +class PersonKeyTuple(NamedTuple): + id_kind: str + id: str ``` -Then you can use `dict[PersonKey, Person]` to represent a KTable keyed by `PersonKey`. +Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by `PersonKey` or `PersonKeyTuple`. #### LTable @@ -118,4 +140,4 @@ Currently, the following types are key types - Range - Uuid - Date -- Struct with all fields being key types +- Struct with all fields being key types (using `@dataclass(frozen=True)` or `NamedTuple`) diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 1f947323d..fad21b653 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -8,13 +8,15 @@ from enum import Enum from typing import Any, Callable, get_origin -from .typing import analyze_type_info, encode_enriched_type, TABLE_TYPES, KEY_FIELD_NAME +from .typing import analyze_type_info, encode_enriched_type, is_namedtuple_type, TABLE_TYPES, KEY_FIELD_NAME def encode_engine_value(value: Any) -> Any: """Encode a Python value to an engine value.""" if dataclasses.is_dataclass(value): return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] + if is_namedtuple_type(type(value)): + return [encode_engine_value(getattr(value, name)) for name in value._fields] if isinstance(value, (list, tuple)): return [encode_engine_value(v) for v in value] if isinstance(value, dict): @@ -55,16 +57,16 @@ def make_engine_value_decoder( f"Type mismatch for `{''.join(field_path)}`: " f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})") - if dst_type_info.dataclass_type is not None: + if dst_type_info.struct_type is not None: return _make_engine_struct_value_decoder( - field_path, src_type['fields'], dst_type_info.dataclass_type) + field_path, src_type['fields'], dst_type_info.struct_type) if src_type_kind in TABLE_TYPES: field_path.append('[*]') elem_type_info = analyze_type_info(dst_type_info.elem_type) - if elem_type_info.dataclass_type is None: + if elem_type_info.struct_type is None: raise ValueError(f"Type mismatch for `{''.join(field_path)}`: " - f"declared `{dst_type_info.kind}`, a dataclass type expected") + f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected") engine_fields_schema = src_type['row']['fields'] if elem_type_info.key_type is not None: key_field_schema = engine_fields_schema[0] @@ -73,14 +75,14 @@ def make_engine_value_decoder( field_path, key_field_schema['type'], elem_type_info.key_type) field_path.pop() value_decoder = _make_engine_struct_value_decoder( - field_path, engine_fields_schema[1:], elem_type_info.dataclass_type) + field_path, engine_fields_schema[1:], elem_type_info.struct_type) def decode(value): if value is None: return None return {key_decoder(v[0]): value_decoder(v[1:]) for v in value} else: elem_decoder = _make_engine_struct_value_decoder( - field_path, engine_fields_schema, elem_type_info.dataclass_type) + field_path, engine_fields_schema, elem_type_info.struct_type) def decode(value): if value is None: return None @@ -96,11 +98,31 @@ def decode(value): def _make_engine_struct_value_decoder( field_path: list[str], src_fields: list[dict[str, Any]], - dst_dataclass_type: type, + dst_struct_type: type, ) -> Callable[[list], Any]: """Make a decoder from an engine field values to a Python value.""" src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)} + + is_dataclass = dataclasses.is_dataclass(dst_struct_type) + is_namedtuple = is_namedtuple_type(dst_struct_type) + + if is_dataclass: + parameters = inspect.signature(dst_struct_type).parameters + elif is_namedtuple: + defaults = getattr(dst_struct_type, '_field_defaults', {}) + parameters = { + name: inspect.Parameter( + name=name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=defaults.get(name, inspect.Parameter.empty), + annotation=dst_struct_type.__annotations__.get(name, inspect.Parameter.empty) + ) + for name in dst_struct_type._fields + } + else: + raise ValueError(f"Unsupported struct type: {dst_struct_type}") + def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]: src_idx = src_name_to_idx.get(name) if src_idx is not None: @@ -108,7 +130,7 @@ def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[lis field_decoder = make_engine_value_decoder( field_path, src_fields[src_idx]['type'], param.annotation) field_path.pop() - return lambda values: field_decoder(values[src_idx]) + return lambda values: field_decoder(values[src_idx]) if len(values) > src_idx else param.default default_value = param.default if default_value is inspect.Parameter.empty: @@ -119,9 +141,9 @@ def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[lis field_value_decoder = [ make_closure_for_value(name, param) - for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()] + for (name, param) in parameters.items()] - return lambda values: dst_dataclass_type( + return lambda values: dst_struct_type( *(decoder(values) for decoder in field_value_decoder)) def dump_engine_object(v: Any) -> Any: diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 9019653b7..33c55ff9e 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -1,11 +1,12 @@ import uuid import datetime from dataclasses import dataclass, make_dataclass +from typing import NamedTuple, Literal import pytest import cocoindex from cocoindex.typing import encode_enriched_type from cocoindex.convert import encode_engine_value, make_engine_value_decoder -from typing import Literal + @dataclass class Order: order_id: str @@ -33,6 +34,17 @@ class NestedStruct: orders: list[Order] count: int = 0 +class OrderNamedTuple(NamedTuple): + order_id: str + name: str + price: float + extra_field: str = "default_extra" + +class CustomerNamedTuple(NamedTuple): + name: str + order: OrderNamedTuple + tags: list[Tag] | None = None + def build_engine_value_decoder(engine_type_in_py, python_type=None): """ Helper to build a converter for the given engine-side type (as represented in Python). @@ -62,10 +74,16 @@ def test_encode_engine_value_date_time_types(): def test_encode_engine_value_struct(): order = Order(order_id="O123", name="mixed nuts", price=25.0) assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"] + + order_nt = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0) + assert encode_engine_value(order_nt) == ["O123", "mixed nuts", 25.0, "default_extra"] def test_encode_engine_value_list_of_structs(): orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)] assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] + + orders_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)] + assert encode_engine_value(orders_nt) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] def test_encode_engine_value_struct_with_list(): basket = Basket(items=["apple", "banana"]) @@ -74,6 +92,9 @@ def test_encode_engine_value_struct_with_list(): def test_encode_engine_value_nested_struct(): customer = Customer(name="Alice", order=Order("O1", "item1", 10.0)) assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None] + + customer_nt = CustomerNamedTuple(name="Alice", order=OrderNamedTuple("O1", "item1", 10.0)) + assert encode_engine_value(customer_nt) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None] def test_encode_engine_value_empty_list(): assert encode_engine_value([]) == [] @@ -103,20 +124,34 @@ def test_make_engine_value_decoder_basic_types(): @pytest.mark.parametrize( "data_type, engine_val, expected", [ - # All fields match + # All fields match (dataclass) (Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")), + # All fields match (NamedTuple) + (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), # Extra field in engine value (should ignore extra) (Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")), + (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), # Fewer fields in engine value (should fill with default) (Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")), + (OrderNamedTuple, ["O123", "mixed nuts", 0.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra")), # More fields in engine value (should ignore extra) (Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")), + (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected")), # Truly extra field (should ignore the fifth field) (Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")), + (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), # Missing optional field in engine value (tags=None) (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)), + (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)), # Extra field in engine value for Customer (should ignore) (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])), + (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip")])), + # Missing optional field with default + (Order, ["O123", "mixed nuts", 25.0], Order("O123", "mixed nuts", 25.0, "default_extra")), + (OrderNamedTuple, ["O123", "mixed nuts", 25.0], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), + # Partial optional fields + (Customer, ["Alice", ["O1", "item1", 10.0]], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)), + (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0]], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)), ] ) def test_struct_decoder_cases(data_type, engine_val, expected): @@ -124,17 +159,27 @@ def test_struct_decoder_cases(data_type, engine_val, expected): assert decoder(engine_val) == expected def test_make_engine_value_decoder_collections(): - # List of structs + # List of structs (dataclass) decoder = build_engine_value_decoder(list[Order]) engine_val = [ ["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"] ] assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")] + + # List of structs (NamedTuple) + decoder = build_engine_value_decoder(list[OrderNamedTuple]) + assert decoder(engine_val) == [OrderNamedTuple("O1", "item1", 10.0, "default_extra"), OrderNamedTuple("O2", "item2", 20.0, "default_extra")] + # Struct with list field decoder = build_engine_value_decoder(Customer) engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]] assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")]) + + # NamedTuple with list field + decoder = build_engine_value_decoder(CustomerNamedTuple) + assert decoder(engine_val) == CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")]) + # Struct with struct field decoder = build_engine_value_decoder(NestedStruct) engine_val = [ @@ -239,6 +284,13 @@ def test_roundtrip_ltable(): assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] decoded = build_engine_value_decoder(t)(encoded) assert decoded == value + + t_nt = list[OrderNamedTuple] + value_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)] + encoded = encode_engine_value(value_nt) + assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] + decoded = build_engine_value_decoder(t_nt)(encoded) + assert decoded == value_nt def test_roundtrip_ktable_str_key(): t = dict[str, Order] @@ -247,6 +299,13 @@ def test_roundtrip_ktable_str_key(): assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]] decoded = build_engine_value_decoder(t)(encoded) assert decoded == value + + t_nt = dict[str, OrderNamedTuple] + value_nt = {"K1": OrderNamedTuple("O1", "item1", 10.0), "K2": OrderNamedTuple("O2", "item2", 20.0)} + encoded = encode_engine_value(value_nt) + assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]] + decoded = build_engine_value_decoder(t_nt)(encoded) + assert decoded == value_nt def test_roundtrip_ktable_struct_key(): @dataclass(frozen=True) @@ -261,6 +320,14 @@ class OrderKey: [["B", 4], "O2", "item2", 20.0, "default_extra"]] decoded = build_engine_value_decoder(t)(encoded) assert decoded == value + + t_nt = dict[OrderKey, OrderNamedTuple] + value_nt = {OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0), OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0)} + encoded = encode_engine_value(value_nt) + assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"], + [["B", 4], "O2", "item2", 20.0, "default_extra"]] + decoded = build_engine_value_decoder(t_nt)(encoded) + assert decoded == value_nt IntVectorType = cocoindex.Vector[int, Literal[5]] def test_vector_as_vector() -> None: diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index a6ec351b0..496ca5d21 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -56,8 +56,11 @@ def __class_getitem__(self, params): ElementType = type | tuple[type, type] +def is_namedtuple_type(t) -> bool: + return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields") + def _is_struct_type(t) -> bool: - return isinstance(t, type) and dataclasses.is_dataclass(t) + return isinstance(t, type) and (dataclasses.is_dataclass(t) or is_namedtuple_type(t)) @dataclasses.dataclass class AnalyzedTypeInfo: @@ -69,7 +72,7 @@ class AnalyzedTypeInfo: elem_type: ElementType | None # For Vector and Table key_type: type | None # For element of KTable - dataclass_type: type | None # For Struct + struct_type: type | None # For Struct, a dataclass or namedtuple attrs: dict[str, Any] | None nullable: bool = False @@ -117,15 +120,16 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: elif isinstance(attr, TypeKind): kind = attr.kind - dataclass_type = None + struct_type = None elem_type = None key_type = None if _is_struct_type(t): + struct_type = t + if kind is None: kind = 'Struct' elif kind != 'Struct': raise ValueError(f"Unexpected type kind for struct: {kind}") - dataclass_type = t elif base_type is collections.abc.Sequence or base_type is list: args = typing.get_args(t) elem_type = args[0] @@ -167,36 +171,50 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: else: raise ValueError(f"type unsupported yet: {t}") - return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, - elem_type=elem_type, key_type=key_type, dataclass_type=dataclass_type, - attrs=attrs, nullable=nullable) - -def _encode_fields_schema(dataclass_type: type, key_type: type | None = None) -> list[dict[str, Any]]: + return AnalyzedTypeInfo( + kind=kind, + vector_info=vector_info, + elem_type=elem_type, + key_type=key_type, + struct_type=struct_type, + attrs=attrs, + nullable=nullable, + ) + +def _encode_fields_schema(struct_type: type, key_type: type | None = None) -> list[dict[str, Any]]: result = [] def add_field(name: str, t) -> None: try: type_info = encode_enriched_type_info(analyze_type_info(t)) except ValueError as e: - e.add_note(f"Failed to encode annotation for field - " - f"{dataclass_type.__name__}.{name}: {t}") + e.add_note( + f"Failed to encode annotation for field - " + f"{struct_type.__name__}.{name}: {t}" + ) raise type_info['name'] = name result.append(type_info) if key_type is not None: add_field(KEY_FIELD_NAME, key_type) - for field in dataclasses.fields(dataclass_type): - add_field(field.name, field.type) + + if dataclasses.is_dataclass(struct_type): + for field in dataclasses.fields(struct_type): + add_field(field.name, field.type) + elif is_namedtuple_type(struct_type): + for name, field_type in struct_type.__annotations__.items(): + add_field(name, field_type) + return result def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: encoded_type: dict[str, Any] = { 'kind': type_info.kind } if type_info.kind == 'Struct': - if type_info.dataclass_type is None: - raise ValueError("Struct type must have a dataclass type") - encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type, type_info.key_type) - if doc := inspect.getdoc(type_info.dataclass_type): + if type_info.struct_type is None: + raise ValueError("Struct type must have a dataclass or namedtuple type") + encoded_type['fields'] = _encode_fields_schema(type_info.struct_type, type_info.key_type) + if doc := inspect.getdoc(type_info.struct_type): encoded_type['description'] = doc elif type_info.kind == 'Vector':