Skip to content

Commit 509d7f4

Browse files
committed
feat: update engine value encoding to return ndarray directly
1 parent b07c180 commit 509d7f4

File tree

3 files changed: 29 additions & 21 deletions

python/cocoindex/convert.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def encode_engine_value(value: Any) -> Any:
3030
if is_namedtuple_type(type(value)):
3131
return [encode_engine_value(getattr(value, name)) for name in value._fields]
3232
if isinstance(value, np.ndarray):
33-
return value.tolist()
33+
return value
3434
if isinstance(value, (list, tuple)):
3535
return [encode_engine_value(v) for v in value]
3636
if isinstance(value, dict):
@@ -138,9 +138,9 @@ def decode_vector(value: Any) -> Any | None:
138138
f"Received null for non-nullable vector `{''.join(field_path)}`"
139139
)
140140

141-
if not isinstance(value, list):
141+
if not isinstance(value, (np.ndarray, list)):
142142
raise TypeError(
143-
f"Expected a list for vector `{''.join(field_path)}`, got {type(value)}"
143+
f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
144144
)
145145
expected_dim = (
146146
dst_type_info.vector_info.dim if dst_type_info.vector_info else None

python/cocoindex/tests/test_convert.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -561,13 +561,13 @@ def test_vector_as_list() -> None:
561561
def test_encode_engine_value_ndarray():
562562
"""Test encoding NDArray vectors to lists for the Rust engine."""
563563
vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
564-
assert encode_engine_value(vec_f32) == [1.0, 2.0, 3.0]
564+
assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0])
565565
vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64)
566-
assert encode_engine_value(vec_f64) == [1.0, 2.0, 3.0]
566+
assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0])
567567
vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
568-
assert encode_engine_value(vec_i64) == [1, 2, 3]
568+
assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3])
569569
vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32)
570-
assert encode_engine_value(vec_nd_f32) == [1.0, 2.0, 3.0]
570+
assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0])
571571

572572

573573
def test_make_engine_value_decoder_ndarray():
@@ -598,21 +598,21 @@ def test_roundtrip_ndarray_vector():
598598
"""Test roundtrip encoding and decoding of NDArray vectors."""
599599
value_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32)
600600
encoded_f32 = encode_engine_value(value_f32)
601-
assert encoded_f32 == [1.0, 2.0, 3.0]
601+
np.array_equal(encoded_f32, [1.0, 2.0, 3.0])
602602
decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32)
603603
assert isinstance(decoded_f32, np.ndarray)
604604
assert decoded_f32.dtype == np.float32
605605
assert np.array_equal(decoded_f32, value_f32)
606606
value_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64)
607607
encoded_i64 = encode_engine_value(value_i64)
608-
assert encoded_i64 == [1, 2, 3]
608+
assert np.array_equal(encoded_i64, [1, 2, 3])
609609
decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64)
610610
assert isinstance(decoded_i64, np.ndarray)
611611
assert decoded_i64.dtype == np.int64
612612
assert np.array_equal(decoded_i64, value_i64)
613613
value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64)
614614
encoded_nd_f64 = encode_engine_value(value_nd_f64)
615-
assert encoded_nd_f64 == [1.0, 2.0, 3.0]
615+
assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0])
616616
decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64)
617617
assert isinstance(decoded_nd_f64, np.ndarray)
618618
assert decoded_nd_f64.dtype == np.float64
@@ -623,7 +623,7 @@ def test_uint_support():
623623
"""Test encoding and decoding of unsigned integer vectors."""
624624
value_uint8 = np.array([1, 2, 3, 4], dtype=np.uint8)
625625
encoded = encode_engine_value(value_uint8)
626-
assert encoded == [1, 2, 3, 4]
626+
assert np.array_equal(encoded, [1, 2, 3, 4])
627627
decoder = make_engine_value_decoder(
628628
[], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint8]
629629
)
@@ -632,7 +632,7 @@ def test_uint_support():
632632
assert decoded.dtype == np.uint8
633633
value_uint16 = np.array([1, 2, 3, 4], dtype=np.uint16)
634634
encoded = encode_engine_value(value_uint16)
635-
assert encoded == [1, 2, 3, 4]
635+
assert np.array_equal(encoded, [1, 2, 3, 4])
636636
decoder = make_engine_value_decoder(
637637
[], {"kind": "Vector", "element_type": {"kind": "UInt16"}}, NDArray[np.uint16]
638638
)
@@ -641,7 +641,7 @@ def test_uint_support():
641641
assert decoded.dtype == np.uint16
642642
value_uint32 = np.array([1, 2, 3], dtype=np.uint32)
643643
encoded = encode_engine_value(value_uint32)
644-
assert encoded == [1, 2, 3]
644+
assert np.array_equal(encoded, [1, 2, 3])
645645
decoder = make_engine_value_decoder(
646646
[], {"kind": "Vector", "element_type": {"kind": "UInt32"}}, NDArray[np.uint32]
647647
)
@@ -650,7 +650,7 @@ def test_uint_support():
650650
assert decoded.dtype == np.uint32
651651
value_uint64 = np.array([1, 2, 3], dtype=np.uint64)
652652
encoded = encode_engine_value(value_uint64)
653-
assert encoded == [1, 2, 3]
653+
assert np.array_equal(encoded, [1, 2, 3])
654654
decoder = make_engine_value_decoder(
655655
[], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint64]
656656
)
@@ -663,7 +663,7 @@ def test_ndarray_dimension_mismatch():
663663
"""Test dimension enforcement for Vector with specified dimension."""
664664
value: Float32VectorType = np.array([1.0, 2.0], dtype=np.float32)
665665
encoded = encode_engine_value(value)
666-
assert encoded == [1.0, 2.0]
666+
assert np.array_equal(encoded, [1.0, 2.0])
667667
with pytest.raises(ValueError, match="Vector dimension mismatch"):
668668
build_engine_value_decoder(Float32VectorType)(encoded)
669669

@@ -679,9 +679,9 @@ def test_list_vector_backward_compatibility():
679679
assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
680680
value_list: ListIntType = [1, 2, 3, 4, 5]
681681
encoded = encode_engine_value(value_list)
682-
assert encoded == [1, 2, 3, 4, 5]
682+
assert np.array_equal(encoded, [1, 2, 3, 4, 5])
683683
decoded = build_engine_value_decoder(ListIntType)(encoded)
684-
assert decoded.tolist() == [1, 2, 3, 4, 5]
684+
assert np.array_equal(decoded, [1, 2, 3, 4, 5])
685685

686686

687687
def test_encode_complex_structure_with_ndarray():
@@ -702,7 +702,9 @@ class MyStructWithNDArray:
702702
[1.0, 0.5],
703703
100,
704704
]
705-
assert encoded == expected
705+
assert encoded[0] == expected[0]
706+
assert np.array_equal(encoded[1], expected[1])
707+
assert encoded[2] == expected[2]
706708

707709

708710
def test_decode_nullable_ndarray_none_or_value_input():
@@ -750,7 +752,7 @@ def test_decode_error_non_nullable_or_non_list_vector():
750752
decoder = make_engine_value_decoder([], src_type_dict, NDArrayFloat32Type)
751753
with pytest.raises(ValueError, match="Received null for non-nullable vector"):
752754
decoder(None)
753-
with pytest.raises(TypeError, match="Expected a list for vector"):
755+
with pytest.raises(TypeError, match="Expected NDArray or list for vector"):
754756
decoder("not a list")
755757

756758

python/cocoindex/typing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class Vector: # type: ignore[unreachable]
6464

6565
def __class_getitem__(self, params):
6666
if not isinstance(params, tuple):
67+
# No dimension provided, e.g., Vector[np.float32]
6768
dtype = params
6869
# Use NDArray for supported numeric dtypes, else list
6970
if DtypeRegistry.get_by_dtype(dtype) is not None:
@@ -78,7 +79,7 @@ def __class_getitem__(self, params):
7879
if typing.get_origin(dim_literal) is Literal
7980
else None
8081
)
81-
if dtype in DtypeRegistry.supported_dtypes():
82+
if DtypeRegistry.get_by_dtype(dtype) is not None:
8283
return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
8384
return Annotated[list[dtype], VectorInfo(dim=dim_val)]
8485

@@ -242,7 +243,12 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
242243

243244
numpy_dtype = dtype_args[0]
244245
dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype)
245-
elem_type = None if dtype_info is None else dtype_info.annotated_type
246+
if dtype_info is None:
247+
raise ValueError(
248+
f"Unsupported numpy dtype for NDArray: {numpy_dtype}. "
249+
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
250+
)
251+
elem_type = dtype_info.annotated_type
246252
vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
247253

248254
elif base_type is collections.abc.Mapping or base_type is dict:

0 commit comments

Comments
 (0)