Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 71 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ name = "cocoindex_engine"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.24.1", features = ["chrono"] }
pythonize = "0.24.0"
pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] }
pyo3 = { version = "0.25.0", features = ["chrono"] }
pythonize = "0.25.0"
pyo3-async-runtimes = { version = "0.25.0", features = ["tokio-runtime"] }

anyhow = { version = "1.0.97", features = ["std"] }
async-trait = "0.1.88"
Expand Down Expand Up @@ -113,3 +113,4 @@ json5 = "0.4.1"
aws-config = "1.6.2"
aws-sdk-s3 = "1.85.0"
aws-sdk-sqs = "1.67.0"
numpy = "0.25.0"
64 changes: 62 additions & 2 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
import datetime
from dataclasses import dataclass, make_dataclass
from typing import NamedTuple, Literal, Any, Callable
from typing import NamedTuple, Literal, Any, Callable, Union
import pytest
import cocoindex
from cocoindex.typing import (
Expand Down Expand Up @@ -91,7 +91,7 @@ def validate_full_roundtrip(
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
value_from_engine
)
assert decoded_value == value
np.testing.assert_array_equal(decoded_value, value)


def test_encode_engine_value_basic_types():
Expand Down Expand Up @@ -540,6 +540,11 @@ def test_vector_as_list() -> None:
# Module-level type aliases for vector annotations.
# NOTE(review): `Literal[3]` looks like the declared vector dimension — confirm against `Vector`.
Float64VectorType = Vector[np.float64, Literal[3]]
Int64VectorType = Vector[np.int64, Literal[3]]
Int32VectorType = Vector[np.int32, Literal[3]]
# Unsigned-integer element dtypes.
UInt8VectorType = Vector[np.uint8, Literal[3]]
UInt16VectorType = Vector[np.uint16, Literal[3]]
UInt32VectorType = Vector[np.uint32, Literal[3]]
UInt64VectorType = Vector[np.uint64, Literal[3]]
# String-element vector declared without a dimension argument.
StrVectorType = Vector[str]
# NDArray aliases parameterized only by element dtype.
NDArrayFloat32Type = NDArray[np.float32]
NDArrayFloat64Type = NDArray[np.float64]
NDArrayInt64Type = NDArray[np.int64]
Expand Down Expand Up @@ -765,3 +770,58 @@ def test_dump_vector_type_annotation_no_dim():
}
}
assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim


def test_roundtrip_vector_numeric_types() -> None:
    """Run `validate_full_roundtrip` on a 3-element vector of each numeric dtype."""
    float_items = [1.0, 2.0, 3.0]
    int_items = [1, 2, 3]
    # Each case pairs a NumPy array with the Vector annotation it is declared as.
    cases: list[tuple[Any, Any]] = [
        (np.array(float_items, dtype=np.float32), Vector[np.float32, Literal[3]]),
        (np.array(float_items, dtype=np.float64), Vector[np.float64, Literal[3]]),
        (np.array(int_items, dtype=np.int32), Vector[np.int32, Literal[3]]),
        (np.array(int_items, dtype=np.int64), Vector[np.int64, Literal[3]]),
        (np.array(int_items, dtype=np.uint8), Vector[np.uint8, Literal[3]]),
        (np.array(int_items, dtype=np.uint16), Vector[np.uint16, Literal[3]]),
        (np.array(int_items, dtype=np.uint32), Vector[np.uint32, Literal[3]]),
        (np.array(int_items, dtype=np.uint64), Vector[np.uint64, Literal[3]]),
    ]
    for vector_value, vector_type in cases:
        validate_full_roundtrip(vector_value, vector_type)


def test_roundtrip_vector_no_dimension() -> None:
    """Run the full roundtrip on a float64 vector whose annotation has no dimension."""
    undimensioned_type = Vector[np.float64]
    samples = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    validate_full_roundtrip(samples, undimensioned_type)


def test_roundtrip_string_vector() -> None:
    """Run the full roundtrip on a `Vector[str]` supplied as a plain list."""
    words = ["hello", "world"]
    validate_full_roundtrip(words, Vector[str])


def test_roundtrip_empty_vector() -> None:
    """Run the full roundtrip on a zero-length float32 vector."""
    no_elements = np.array([], dtype=np.float32)
    validate_full_roundtrip(no_elements, Vector[np.float32])


def test_roundtrip_dimension_mismatch() -> None:
    """Expect a ValueError when a 2-element array is declared with dimension 3."""
    too_short = np.array([1.0, 2.0], dtype=np.float32)
    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        validate_full_roundtrip(too_short, Vector[np.float32, Literal[3]])


def test_roundtrip_list_backward_compatibility() -> None:
    """Run the full roundtrip on a plain `list[int]` value (backward compatibility)."""
    plain_ints = [1, 2, 3]
    validate_full_roundtrip(plain_ints, list[int])
20 changes: 18 additions & 2 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:


class DtypeRegistry:
"""
Registry for NumPy dtypes used in CocoIndex.
Provides mappings from NumPy dtypes to CocoIndex's type representation.
"""

_mappings: dict[type, DtypeInfo] = {
np.float32: DtypeInfo(np.float32, "Float32", float),
np.float64: DtypeInfo(np.float64, "Float64", float),
Expand All @@ -124,6 +129,7 @@ class DtypeRegistry:

@classmethod
def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
"""Get DtypeInfo by NumPy dtype."""
if dtype is Any:
raise TypeError(
"NDArray for Vector must use a concrete numpy dtype, got `Any`."
Expand All @@ -132,13 +138,21 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:

@staticmethod
def get_by_kind(kind: str) -> DtypeInfo | None:
    """
    Look up the registered DtypeInfo whose `kind` equals the given string.

    Returns the first matching entry in the registry's iteration order,
    or None when no registered entry has that kind.
    """
    for candidate in DtypeRegistry._mappings.values():
        if candidate.kind == kind:
            return candidate
    return None

@staticmethod
def rust_compatible_kind(kind: str) -> str:
    """
    Map a kind name to a Rust-compatible kind for schema encoding.

    Any kind whose name contains "Int" (for example "Int32", "UInt8",
    "UInt16", "UInt32" or "UInt64") is reported as "Int64"; every other
    kind is returned unchanged.
    """
    # The substring check covers both signed ("Int*") and unsigned
    # ("UInt*") integer kinds with a single test.
    if "Int" not in kind:
        return kind
    return "Int64"

@staticmethod
def supported_dtypes() -> KeysView[type]:
    """Return a keys view over the NumPy dtypes registered in the mapping."""
    registry = DtypeRegistry._mappings
    return registry.keys()


Expand Down Expand Up @@ -340,8 +354,10 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
raise ValueError("Vector type must have a vector info")
if type_info.elem_type is None:
raise ValueError("Vector type must have an element type")
encoded_type["element_type"] = _encode_type(
analyze_type_info(type_info.elem_type)
elem_type_info = analyze_type_info(type_info.elem_type)
encoded_type["element_type"] = _encode_type(elem_type_info)
encoded_type["element_type"]["kind"] = DtypeRegistry.rust_compatible_kind(
elem_type_info.kind
)
encoded_type["dimension"] = type_info.vector_info.dim

Expand Down
Loading