Skip to content

Commit 0f4e843

Browse files
committed
test: add roundtrip tests for numeric and string vector types
1 parent 246cdd2 commit 0f4e843

File tree

3 files changed

+81
-4
lines changed

3 files changed

+81
-4
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import uuid
22
import datetime
33
from dataclasses import dataclass, make_dataclass
4-
from typing import NamedTuple, Literal, Any, Callable
4+
from typing import NamedTuple, Literal, Any, Callable, Union
55
import pytest
66
import cocoindex
77
from cocoindex.typing import (
@@ -91,7 +91,7 @@ def validate_full_roundtrip(
9191
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
9292
value_from_engine
9393
)
94-
assert decoded_value == value
94+
np.testing.assert_array_equal(decoded_value, value)
9595

9696

9797
def test_encode_engine_value_basic_types():
@@ -540,6 +540,11 @@ def test_vector_as_list() -> None:
540540
Float64VectorType = Vector[np.float64, Literal[3]]
541541
Int64VectorType = Vector[np.int64, Literal[3]]
542542
Int32VectorType = Vector[np.int32, Literal[3]]
543+
UInt8VectorType = Vector[np.uint8, Literal[3]]
544+
UInt16VectorType = Vector[np.uint16, Literal[3]]
545+
UInt32VectorType = Vector[np.uint32, Literal[3]]
546+
UInt64VectorType = Vector[np.uint64, Literal[3]]
547+
StrVectorType = Vector[str]
543548
NDArrayFloat32Type = NDArray[np.float32]
544549
NDArrayFloat64Type = NDArray[np.float64]
545550
NDArrayInt64Type = NDArray[np.int64]
@@ -765,3 +770,58 @@ def test_dump_vector_type_annotation_no_dim():
765770
}
766771
}
767772
assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim
773+
774+
775+
def test_roundtrip_vector_numeric_types() -> None:
    """Round-trip one 3-element NDArray vector for every numeric dtype under test."""
    float_items = [1.0, 2.0, 3.0]
    int_items = [1, 2, 3]
    # Each case pairs an NDArray value with the Vector type it is declared as.
    # Cases are validated in the listed order: floats, signed ints, unsigned ints.
    cases: list[tuple[Any, Any]] = [
        (np.array(float_items, dtype=np.float32), Vector[np.float32, Literal[3]]),
        (np.array(float_items, dtype=np.float64), Vector[np.float64, Literal[3]]),
        (np.array(int_items, dtype=np.int32), Vector[np.int32, Literal[3]]),
        (np.array(int_items, dtype=np.int64), Vector[np.int64, Literal[3]]),
        (np.array(int_items, dtype=np.uint8), Vector[np.uint8, Literal[3]]),
        (np.array(int_items, dtype=np.uint16), Vector[np.uint16, Literal[3]]),
        (np.array(int_items, dtype=np.uint32), Vector[np.uint32, Literal[3]]),
        (np.array(int_items, dtype=np.uint64), Vector[np.uint64, Literal[3]]),
    ]
    for value, vector_type in cases:
        validate_full_roundtrip(value, vector_type)
797+
798+
799+
def test_roundtrip_vector_no_dimension() -> None:
    """Round-trip a float64 NDArray declared as a Vector with no dimension argument."""
    unsized_type = Vector[np.float64]
    samples = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    validate_full_roundtrip(samples, unsized_type)
803+
804+
805+
def test_roundtrip_string_vector() -> None:
    """Round-trip a plain Python list of strings declared as ``Vector[str]``."""
    words: Vector[str] = ["hello", "world"]
    validate_full_roundtrip(words, Vector[str])
809+
810+
811+
def test_roundtrip_empty_vector() -> None:
    """Round-trip a zero-length float32 NDArray declared as an unsized Vector."""
    unsized_type = Vector[np.float32]
    no_elements = np.array([], dtype=np.float32)
    validate_full_roundtrip(no_elements, unsized_type)
815+
816+
817+
def test_roundtrip_dimension_mismatch() -> None:
    """Expect a ValueError when a 2-element array is round-tripped as a 3-dim Vector."""
    declared_type = Vector[np.float32, Literal[3]]
    # Deliberately one element short of the declared dimension.
    too_short = np.array([1.0, 2.0], dtype=np.float32)
    with pytest.raises(ValueError, match="Vector dimension mismatch"):
        validate_full_roundtrip(too_short, declared_type)
822+
823+
824+
def test_roundtrip_list_backward_compatibility() -> None:
    """Round-trip a plain ``list[int]`` value to cover list-based vectors."""
    list_type = list[int]
    numbers: list[int] = [1, 2, 3]
    validate_full_roundtrip(numbers, list_type)

python/cocoindex/typing.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
111111

112112

113113
class DtypeRegistry:
114+
"""
115+
Registry for NumPy dtypes used in CocoIndex.
116+
Provides mappings from NumPy dtypes to CocoIndex's type representation.
117+
"""
118+
114119
_mappings: dict[type, DtypeInfo] = {
115120
np.float32: DtypeInfo(np.float32, "Float32", float),
116121
np.float64: DtypeInfo(np.float64, "Float64", float),
@@ -124,6 +129,7 @@ class DtypeRegistry:
124129

125130
@classmethod
126131
def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
132+
"""Get DtypeInfo by NumPy dtype."""
127133
if dtype is Any:
128134
raise TypeError(
129135
"NDArray for Vector must use a concrete numpy dtype, got `Any`."
@@ -132,13 +138,21 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
132138

133139
@staticmethod
def get_by_kind(kind: str) -> DtypeInfo | None:
    """Return the first registered DtypeInfo whose ``kind`` equals *kind*, else None."""
    # Scan in registration order; the first match wins.
    for candidate in DtypeRegistry._mappings.values():
        if candidate.kind == kind:
            return candidate
    return None
139146

147+
@staticmethod
148+
def rust_compatible_kind(kind: str) -> str:
149+
"""Map to a Rust-compatible kind for schema encoding."""
150+
# incompatible_integer_kinds = {"Int32", "UInt8", "UInt16", "UInt32", "UInt64"}
151+
return "Int64" if "Int" in kind else kind
152+
140153
@staticmethod
def supported_dtypes() -> KeysView[type]:
    """Return a keys view over the NumPy dtypes registered in the mapping table."""
    registered = DtypeRegistry._mappings
    return registered.keys()
143157

144158

@@ -340,8 +354,10 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
340354
raise ValueError("Vector type must have a vector info")
341355
if type_info.elem_type is None:
342356
raise ValueError("Vector type must have an element type")
343-
encoded_type["element_type"] = _encode_type(
344-
analyze_type_info(type_info.elem_type)
357+
elem_type_info = analyze_type_info(type_info.elem_type)
358+
encoded_type["element_type"] = _encode_type(elem_type_info)
359+
encoded_type["element_type"]["kind"] = DtypeRegistry.rust_compatible_kind(
360+
elem_type_info.kind
345361
)
346362
encoded_type["dimension"] = type_info.vector_info.dim
347363

src/py/convert.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ fn basic_value_from_py_object<'py>(
165165
Ok(result)
166166
}
167167

168+
// Helper function to convert PyAny to BasicValue for NDArray
168169
fn handle_ndarray_from_py<'py>(
169170
elem_type: &schema::BasicValueType,
170171
v: &Bound<'py, PyAny>,

0 commit comments

Comments
 (0)