Skip to content

Commit 0586b8e

Browse files
authored
fix: remove unsupported type casting in integer vector handling (#619)
1 parent 91b1e84 commit 0586b8e

File tree

4 files changed

+12
-67
lines changed

4 files changed

+12
-67
lines changed

python/cocoindex/tests/test_convert.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ class OrderKey:
513513
validate_full_roundtrip(value_nt, t_nt)
514514

515515

516-
IntVectorType = cocoindex.Vector[np.int32, Literal[5]]
516+
IntVectorType = cocoindex.Vector[np.int64, Literal[5]]
517517

518518

519519
def test_vector_as_vector() -> None:
@@ -611,37 +611,6 @@ def test_roundtrip_ndarray_vector():
611611
assert np.array_equal(decoded_nd_f64, value_nd_f64)
612612

613613

614-
def test_uint_support():
615-
"""Test encoding and decoding of unsigned integer vectors."""
616-
value_uint8 = np.array([1, 2, 3, 4], dtype=np.uint8)
617-
encoded = encode_engine_value(value_uint8)
618-
assert np.array_equal(encoded, [1, 2, 3, 4])
619-
decoder = make_engine_value_decoder(
620-
[], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint8]
621-
)
622-
decoded = decoder(encoded)
623-
assert np.array_equal(decoded, value_uint8)
624-
assert decoded.dtype == np.uint8
625-
value_uint16 = np.array([1, 2, 3, 4], dtype=np.uint16)
626-
encoded = encode_engine_value(value_uint16)
627-
assert np.array_equal(encoded, [1, 2, 3, 4])
628-
decoder = make_engine_value_decoder(
629-
[], {"kind": "Vector", "element_type": {"kind": "UInt16"}}, NDArray[np.uint16]
630-
)
631-
decoded = decoder(encoded)
632-
assert np.array_equal(decoded, value_uint16)
633-
assert decoded.dtype == np.uint16
634-
value_uint32 = np.array([1, 2, 3], dtype=np.uint32)
635-
encoded = encode_engine_value(value_uint32)
636-
assert np.array_equal(encoded, [1, 2, 3])
637-
decoder = make_engine_value_decoder(
638-
[], {"kind": "Vector", "element_type": {"kind": "UInt32"}}, NDArray[np.uint32]
639-
)
640-
decoded = decoder(encoded)
641-
assert np.array_equal(decoded, value_uint32)
642-
assert decoded.dtype == np.uint32
643-
644-
645614
def test_ndarray_dimension_mismatch():
646615
"""Test dimension enforcement for Vector with specified dimension."""
647616
value: Float32VectorType = np.array([1.0, 2.0], dtype=np.float32)
@@ -658,7 +627,7 @@ def test_list_vector_backward_compatibility():
658627
assert encoded == [1, 2, 3, 4, 5]
659628
decoded = build_engine_value_decoder(IntVectorType)(encoded)
660629
assert isinstance(decoded, np.ndarray)
661-
assert decoded.dtype == np.int32
630+
assert decoded.dtype == np.int64
662631
assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64))
663632
value_list: ListIntType = [1, 2, 3, 4, 5]
664633
encoded = encode_engine_value(value_list)
@@ -773,16 +742,20 @@ def test_full_roundtrip_vector_numeric_types() -> None:
773742
[1.0, 2.0, 3.0], dtype=np.float64
774743
)
775744
validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]])
776-
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
777-
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
778745
value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
779746
validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
747+
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
748+
with pytest.raises(ValueError, match="type unsupported yet"):
749+
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
780750
value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
781-
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
751+
with pytest.raises(ValueError, match="type unsupported yet"):
752+
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
782753
value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
783-
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
754+
with pytest.raises(ValueError, match="type unsupported yet"):
755+
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
784756
value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
785-
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
757+
with pytest.raises(ValueError, match="type unsupported yet"):
758+
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
786759
value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
787760
with pytest.raises(ValueError, match="type unsupported yet"):
788761
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])

python/cocoindex/tests/test_typing.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,24 +108,6 @@ def test_ndarray_int64_no_dim():
108108
assert not result.nullable
109109

110110

111-
def test_ndarray_int32_with_dim():
112-
typ = Annotated[NDArray[np.int32], VectorInfo(dim=10)]
113-
result = analyze_type_info(typ)
114-
assert result.kind == "Vector"
115-
assert result.vector_info == VectorInfo(dim=10)
116-
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
117-
assert not result.nullable
118-
119-
120-
def test_ndarray_uint8_no_dim():
121-
typ = NDArray[np.uint8]
122-
result = analyze_type_info(typ)
123-
assert result.kind == "Vector"
124-
assert result.vector_info == VectorInfo(dim=None)
125-
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
126-
assert not result.nullable
127-
128-
129111
def test_nullable_ndarray():
130112
typ = NDArray[np.float32] | None
131113
result = analyze_type_info(typ)

python/cocoindex/typing.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,7 @@ class DtypeRegistry:
119119
_mappings: dict[type, DtypeInfo] = {
120120
np.float32: DtypeInfo(np.float32, "Float32", float),
121121
np.float64: DtypeInfo(np.float64, "Float64", float),
122-
np.int32: DtypeInfo(np.int32, "Int64", int),
123122
np.int64: DtypeInfo(np.int64, "Int64", int),
124-
np.uint8: DtypeInfo(np.uint8, "Int64", int),
125-
np.uint16: DtypeInfo(np.uint16, "Int64", int),
126-
np.uint32: DtypeInfo(np.uint32, "Int64", int),
127123
}
128124

129125
@classmethod

src/py/convert.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,7 @@ fn handle_ndarray_from_py<'py>(
184184
match elem_type {
185185
&schema::BasicValueType::Float32 => try_convert!(f32, value::BasicValue::Float32),
186186
&schema::BasicValueType::Float64 => try_convert!(f64, value::BasicValue::Float64),
187-
&schema::BasicValueType::Int64 => {
188-
try_convert!(i32, |v| value::BasicValue::Int64(v as i64));
189-
try_convert!(i64, value::BasicValue::Int64);
190-
try_convert!(u8, |v| value::BasicValue::Int64(v as i64));
191-
try_convert!(u16, |v| value::BasicValue::Int64(v as i64));
192-
try_convert!(u32, |v| value::BasicValue::Int64(v as i64));
193-
}
187+
&schema::BasicValueType::Int64 => try_convert!(i64, value::BasicValue::Int64),
194188
_ => {}
195189
}
196190

0 commit comments

Comments
 (0)