Skip to content

Commit e4dcb9a

Browse files
authored
feat: handle NumPy array vector in Python conversion (#595)
* feat: handle numpy array vector in Python conversion
* test: add roundtrip tests for numeric and string vector types
* feat: numeric type uint64 is unsupported
* feat: support dtype decoding by adding `np_number_type` to `AnalyzedTypeInfo`
1 parent f85e686 commit e4dcb9a

File tree

7 files changed

+287
-62
lines changed

7 files changed

+287
-62
lines changed

Cargo.lock

Lines changed: 71 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@ name = "cocoindex_engine"
1515
crate-type = ["cdylib"]
1616

1717
[dependencies]
18-
pyo3 = { version = "0.24.1", features = ["chrono"] }
19-
pythonize = "0.24.0"
20-
pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] }
18+
pyo3 = { version = "0.25.0", features = ["chrono"] }
19+
pythonize = "0.25.0"
20+
pyo3-async-runtimes = { version = "0.25.0", features = ["tokio-runtime"] }
2121

2222
anyhow = { version = "1.0.97", features = ["std"] }
2323
async-trait = "0.1.88"
@@ -114,3 +114,4 @@ json5 = "0.4.1"
114114
aws-config = "1.6.2"
115115
aws-sdk-s3 = "1.85.0"
116116
aws-sdk-sqs = "1.67.0"
117+
numpy = "0.25.0"

python/cocoindex/convert.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -127,8 +127,7 @@ def decode(value: Any) -> Any | None:
127127
return lambda value: uuid.UUID(bytes=value)
128128

129129
if src_type_kind == "Vector":
130-
elem_coco_type_info = analyze_type_info(dst_type_info.elem_type)
131-
dtype_info = DtypeRegistry.get_by_kind(elem_coco_type_info.kind)
130+
dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type)
132131

133132
def decode_vector(value: Any) -> Any | None:
134133
if value is None:

python/cocoindex/tests/test_convert.py

Lines changed: 63 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import uuid
22
import datetime
33
from dataclasses import dataclass, make_dataclass
4-
from typing import NamedTuple, Literal, Any, Callable
4+
from typing import NamedTuple, Literal, Any, Callable, Union
55
import pytest
66
import cocoindex
77
from cocoindex.typing import (
@@ -91,7 +91,7 @@ def validate_full_roundtrip(
9191
decoded_value = build_engine_value_decoder(input_type or output_type, output_type)(
9292
value_from_engine
9393
)
94-
assert decoded_value == value
94+
np.testing.assert_array_equal(decoded_value, value)
9595

9696

9797
def test_encode_engine_value_basic_types():
@@ -540,6 +540,11 @@ def test_vector_as_list() -> None:
540540
Float64VectorType = Vector[np.float64, Literal[3]]
541541
Int64VectorType = Vector[np.int64, Literal[3]]
542542
Int32VectorType = Vector[np.int32, Literal[3]]
543+
UInt8VectorType = Vector[np.uint8, Literal[3]]
544+
UInt16VectorType = Vector[np.uint16, Literal[3]]
545+
UInt32VectorType = Vector[np.uint32, Literal[3]]
546+
UInt64VectorType = Vector[np.uint64, Literal[3]]
547+
StrVectorType = Vector[str]
543548
NDArrayFloat32Type = NDArray[np.float32]
544549
NDArrayFloat64Type = NDArray[np.float64]
545550
NDArrayInt64Type = NDArray[np.int64]
@@ -635,15 +640,6 @@ def test_uint_support():
635640
decoded = decoder(encoded)
636641
assert np.array_equal(decoded, value_uint32)
637642
assert decoded.dtype == np.uint32
638-
value_uint64 = np.array([1, 2, 3], dtype=np.uint64)
639-
encoded = encode_engine_value(value_uint64)
640-
assert np.array_equal(encoded, [1, 2, 3])
641-
decoder = make_engine_value_decoder(
642-
[], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint64]
643-
)
644-
decoded = decoder(encoded)
645-
assert np.array_equal(decoded, value_uint64)
646-
assert decoded.dtype == np.uint64
647643

648644

649645
def test_ndarray_dimension_mismatch():
@@ -765,3 +761,59 @@ def test_dump_vector_type_annotation_no_dim():
765761
}
766762
}
767763
assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim
764+
765+
766+
def test_full_roundtrip_vector_numeric_types() -> None:
767+
"""Test full roundtrip for numeric vector types using NDArray."""
768+
value_f32: Vector[np.float32, Literal[3]] = np.array(
769+
[1.0, 2.0, 3.0], dtype=np.float32
770+
)
771+
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
772+
value_f64: Vector[np.float64, Literal[3]] = np.array(
773+
[1.0, 2.0, 3.0], dtype=np.float64
774+
)
775+
validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]])
776+
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
777+
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
778+
value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
779+
validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
780+
value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
781+
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
782+
value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
783+
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
784+
value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
785+
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
786+
value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
787+
with pytest.raises(ValueError, match="type unsupported yet"):
788+
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
789+
790+
791+
def test_roundtrip_vector_no_dimension() -> None:
792+
"""Test full roundtrip for vector types without dimension annotation."""
793+
value_f64: Vector[np.float64] = np.array([1.0, 2.0, 3.0], dtype=np.float64)
794+
validate_full_roundtrip(value_f64, Vector[np.float64])
795+
796+
797+
def test_roundtrip_string_vector() -> None:
798+
"""Test full roundtrip for string vector using list."""
799+
value_str: Vector[str] = ["hello", "world"]
800+
validate_full_roundtrip(value_str, Vector[str])
801+
802+
803+
def test_roundtrip_empty_vector() -> None:
804+
"""Test full roundtrip for empty numeric vector."""
805+
value_empty: Vector[np.float32] = np.array([], dtype=np.float32)
806+
validate_full_roundtrip(value_empty, Vector[np.float32])
807+
808+
809+
def test_roundtrip_dimension_mismatch() -> None:
810+
"""Test that dimension mismatch raises an error during roundtrip."""
811+
value_f32: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32)
812+
with pytest.raises(ValueError, match="Vector dimension mismatch"):
813+
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
814+
815+
816+
def test_roundtrip_list_backward_compatibility() -> None:
817+
"""Test full roundtrip for list-based vectors for backward compatibility."""
818+
value_list: list[int] = [1, 2, 3]
819+
validate_full_roundtrip(value_list, list[int])

0 commit comments

Comments
 (0)