Skip to content

Commit 316cdf3

Browse files
authored
feat: support scalar NumPy value encodings (#620)
* feat: support scalar numpy value encodings and conversions
* feat: add numpy number type detection logic
* feat: enhance vector decoding for non-NDArray types
* fix: reorder imports and annotate return types
* refactor: simplify DtypeRegistry structure and associated methods
* feat: remove DtypeRegistry getter method
1 parent 23f8cc0 commit 316cdf3

File tree

4 files changed

+337
-166
lines changed

4 files changed

+337
-166
lines changed

python/cocoindex/convert.py

Lines changed: 44 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -6,17 +6,19 @@
66
import datetime
77
import inspect
88
import uuid
9+
from enum import Enum
10+
from typing import Any, Callable, Mapping, get_origin
11+
912
import numpy as np
1013

11-
from enum import Enum
12-
from typing import Any, Callable, get_origin, Mapping
1314
from .typing import (
15+
KEY_FIELD_NAME,
16+
TABLE_TYPES,
17+
DtypeRegistry,
1418
analyze_type_info,
1519
encode_enriched_type,
20+
extract_ndarray_scalar_dtype,
1621
is_namedtuple_type,
17-
TABLE_TYPES,
18-
KEY_FIELD_NAME,
19-
DtypeRegistry,
2022
)
2123

2224

@@ -29,6 +31,8 @@ def encode_engine_value(value: Any) -> Any:
2931
]
3032
if is_namedtuple_type(type(value)):
3133
return [encode_engine_value(getattr(value, name)) for name in value._fields]
34+
if isinstance(value, np.number):
35+
return value.item()
3236
if isinstance(value, np.ndarray):
3337
return value
3438
if isinstance(value, (list, tuple)):
@@ -86,6 +90,20 @@ def make_engine_value_decoder(
8690
field_path, src_type["fields"], dst_type_info.struct_type
8791
)
8892

93+
if dst_type_info.np_number_type is not None and src_type_kind != "Vector":
94+
numpy_type = dst_type_info.np_number_type
95+
96+
def decode_numpy_scalar(value: Any) -> Any | None:
97+
if value is None:
98+
if dst_type_info.nullable:
99+
return None
100+
raise ValueError(
101+
f"Received null for non-nullable scalar `{''.join(field_path)}`"
102+
)
103+
return numpy_type(value)
104+
105+
return decode_numpy_scalar
106+
89107
if src_type_kind in TABLE_TYPES:
90108
field_path.append("[*]")
91109
elem_type_info = analyze_type_info(dst_type_info.elem_type)
@@ -127,33 +145,42 @@ def decode(value: Any) -> Any | None:
127145
return lambda value: uuid.UUID(bytes=value)
128146

129147
if src_type_kind == "Vector":
130-
dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type)
131148

132149
def decode_vector(value: Any) -> Any | None:
150+
field_path_str = "".join(field_path)
151+
expected_dim = (
152+
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
153+
)
154+
133155
if value is None:
134156
if dst_type_info.nullable:
135157
return None
136158
raise ValueError(
137-
f"Received null for non-nullable vector `{''.join(field_path)}`"
159+
f"Received null for non-nullable vector `{field_path_str}`"
138160
)
139-
140161
if not isinstance(value, (np.ndarray, list)):
141162
raise TypeError(
142-
f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}"
163+
f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
143164
)
144-
expected_dim = (
145-
dst_type_info.vector_info.dim if dst_type_info.vector_info else None
146-
)
147165
if expected_dim is not None and len(value) != expected_dim:
148166
raise ValueError(
149-
f"Vector dimension mismatch for `{''.join(field_path)}`: "
167+
f"Vector dimension mismatch for `{field_path_str}`: "
150168
f"expected {expected_dim}, got {len(value)}"
151169
)
152170

153-
# Use NDArray for supported numeric dtypes, else return list
154-
if dtype_info is not None:
155-
return np.array(value, dtype=dtype_info.numpy_dtype)
156-
return value
171+
if dst_type_info.np_number_type is None: # for Non-NDArray vector
172+
elem_decoder = make_engine_value_decoder(
173+
field_path + ["[*]"],
174+
src_type["element_type"],
175+
dst_type_info.elem_type,
176+
)
177+
return [elem_decoder(v) for v in value]
178+
else: # for NDArray vector
179+
scalar_dtype = extract_ndarray_scalar_dtype(
180+
dst_type_info.np_number_type
181+
)
182+
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
183+
return np.array(value, dtype=scalar_dtype)
157184

158185
return decode_vector
159186

python/cocoindex/tests/test_convert.py

Lines changed: 158 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,25 @@
1-
import uuid
21
import datetime
2+
import uuid
33
from dataclasses import dataclass, make_dataclass
4-
from typing import NamedTuple, Literal, Any, Callable, Union
4+
from typing import Annotated, Any, Callable, Literal, NamedTuple
5+
6+
import numpy as np
57
import pytest
8+
from numpy.typing import NDArray
9+
610
import cocoindex
7-
from cocoindex.typing import (
8-
encode_enriched_type,
9-
Vector,
10-
Float32,
11-
Float64,
12-
)
1311
from cocoindex.convert import (
12+
dump_engine_object,
1413
encode_engine_value,
1514
make_engine_value_decoder,
16-
dump_engine_object,
1715
)
18-
import numpy as np
19-
from numpy.typing import NDArray
16+
from cocoindex.typing import (
17+
Float32,
18+
Float64,
19+
TypeKind,
20+
Vector,
21+
encode_enriched_type,
22+
)
2023

2124

2225
@dataclass
@@ -128,6 +131,19 @@ def test_encode_engine_value_date_time_types() -> None:
128131
assert encode_engine_value(dt) == dt
129132

130133

134+
def test_encode_scalar_numpy_values() -> None:
135+
"""Test encoding scalar NumPy values to engine-compatible values."""
136+
test_cases = [
137+
(np.int64(42), 42),
138+
(np.float32(3.14), pytest.approx(3.14)),
139+
(np.float64(2.718), pytest.approx(2.718)),
140+
]
141+
for np_value, expected in test_cases:
142+
encoded = encode_engine_value(np_value)
143+
assert encoded == expected
144+
assert isinstance(encoded, (int, float))
145+
146+
131147
def test_encode_engine_value_struct() -> None:
132148
order = Order(order_id="O123", name="mixed nuts", price=25.0)
133149
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
@@ -213,6 +229,47 @@ def test_roundtrip_basic_types() -> None:
213229
)
214230

215231

232+
def test_decode_scalar_numpy_values() -> None:
233+
test_cases = [
234+
({"kind": "Int64"}, np.int64, 42, np.int64(42)),
235+
({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
236+
({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
237+
]
238+
for src_type, dst_type, input_value, expected in test_cases:
239+
decoder = make_engine_value_decoder(["field"], src_type, dst_type)
240+
result = decoder(input_value)
241+
assert isinstance(result, dst_type)
242+
assert result == expected
243+
244+
245+
def test_non_ndarray_vector_decoding() -> None:
246+
# Test list[np.float64]
247+
src_type = {
248+
"kind": "Vector",
249+
"element_type": {"kind": "Float64"},
250+
"dimension": None,
251+
}
252+
dst_type_float = list[np.float64]
253+
decoder = make_engine_value_decoder(["field"], src_type, dst_type_float)
254+
input_numbers = [1.0, 2.0, 3.0]
255+
result = decoder(input_numbers)
256+
assert isinstance(result, list)
257+
assert all(isinstance(x, np.float64) for x in result)
258+
assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)]
259+
260+
# Test list[Uuid]
261+
src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None}
262+
dst_type_uuid = list[uuid.UUID]
263+
decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid)
264+
uuid1 = uuid.uuid4()
265+
uuid2 = uuid.uuid4()
266+
input_bytes = [uuid1.bytes, uuid2.bytes]
267+
result = decoder(input_bytes)
268+
assert isinstance(result, list)
269+
assert all(isinstance(x, uuid.UUID) for x in result)
270+
assert result == [uuid1, uuid2]
271+
272+
216273
@pytest.mark.parametrize(
217274
"data_type, engine_val, expected",
218275
[
@@ -565,12 +622,6 @@ def test_vector_as_list() -> None:
565622
Float32VectorType = Vector[np.float32, Literal[3]]
566623
Float64VectorType = Vector[np.float64, Literal[3]]
567624
Int64VectorType = Vector[np.int64, Literal[3]]
568-
Int32VectorType = Vector[np.int32, Literal[3]]
569-
UInt8VectorType = Vector[np.uint8, Literal[3]]
570-
UInt16VectorType = Vector[np.uint16, Literal[3]]
571-
UInt32VectorType = Vector[np.uint32, Literal[3]]
572-
UInt64VectorType = Vector[np.uint64, Literal[3]]
573-
StrVectorType = Vector[str]
574625
NDArrayFloat32Type = NDArray[np.float32]
575626
NDArrayFloat64Type = NDArray[np.float64]
576627
NDArrayInt64Type = NDArray[np.int64]
@@ -767,19 +818,19 @@ def test_full_roundtrip_vector_numeric_types() -> None:
767818
value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64)
768819
validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]])
769820
value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32)
770-
with pytest.raises(ValueError, match="type unsupported yet"):
821+
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
771822
validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]])
772823
value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8)
773-
with pytest.raises(ValueError, match="type unsupported yet"):
824+
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
774825
validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]])
775826
value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16)
776-
with pytest.raises(ValueError, match="type unsupported yet"):
827+
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
777828
validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]])
778829
value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32)
779-
with pytest.raises(ValueError, match="type unsupported yet"):
830+
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
780831
validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]])
781832
value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64)
782-
with pytest.raises(ValueError, match="type unsupported yet"):
833+
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
783834
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
784835

785836

@@ -808,7 +859,88 @@ def test_roundtrip_dimension_mismatch() -> None:
808859
validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]])
809860

810861

811-
def test_roundtrip_list_backward_compatibility() -> None:
812-
"""Test full roundtrip for list-based vectors for backward compatibility."""
813-
value_list: list[int] = [1, 2, 3]
814-
validate_full_roundtrip(value_list, list[int])
862+
def test_full_roundtrip_scalar_numeric_types() -> None:
863+
"""Test full roundtrip for scalar NumPy numeric types."""
864+
# Test supported scalar types
865+
validate_full_roundtrip(np.int64(42), np.int64)
866+
validate_full_roundtrip(np.float32(3.14), np.float32)
867+
validate_full_roundtrip(np.float64(2.718), np.float64)
868+
869+
# Test unsupported scalar types
870+
for unsupported_type in [np.int32, np.uint8, np.uint16, np.uint32, np.uint64]:
871+
with pytest.raises(ValueError, match="Unsupported NumPy dtype"):
872+
validate_full_roundtrip(unsupported_type(1), unsupported_type)
873+
874+
875+
def test_full_roundtrip_nullable_scalar() -> None:
876+
"""Test full roundtrip for nullable scalar NumPy types."""
877+
# Test with non-null values
878+
validate_full_roundtrip(np.int64(42), np.int64 | None)
879+
validate_full_roundtrip(np.float32(3.14), np.float32 | None)
880+
validate_full_roundtrip(np.float64(2.718), np.float64 | None)
881+
882+
# Test with None
883+
validate_full_roundtrip(None, np.int64 | None)
884+
validate_full_roundtrip(None, np.float32 | None)
885+
validate_full_roundtrip(None, np.float64 | None)
886+
887+
888+
def test_full_roundtrip_scalar_in_struct() -> None:
889+
"""Test full roundtrip for scalar NumPy types in a dataclass."""
890+
891+
@dataclass
892+
class NumericStruct:
893+
int_field: np.int64
894+
float32_field: np.float32
895+
float64_field: np.float64
896+
897+
instance = NumericStruct(
898+
int_field=np.int64(42),
899+
float32_field=np.float32(3.14),
900+
float64_field=np.float64(2.718),
901+
)
902+
validate_full_roundtrip(instance, NumericStruct)
903+
904+
905+
def test_full_roundtrip_scalar_in_nested_struct() -> None:
906+
"""Test full roundtrip for scalar NumPy types in a nested struct."""
907+
908+
@dataclass
909+
class InnerStruct:
910+
value: np.float64
911+
912+
@dataclass
913+
class OuterStruct:
914+
inner: InnerStruct
915+
count: np.int64
916+
917+
instance = OuterStruct(
918+
inner=InnerStruct(value=np.float64(2.718)),
919+
count=np.int64(1),
920+
)
921+
validate_full_roundtrip(instance, OuterStruct)
922+
923+
924+
def test_full_roundtrip_scalar_with_python_types() -> None:
925+
"""Test full roundtrip for structs mixing NumPy and Python scalar types."""
926+
927+
@dataclass
928+
class MixedStruct:
929+
numpy_int: np.int64
930+
python_int: int
931+
numpy_float: np.float64
932+
python_float: float
933+
string: str
934+
annotated_int: Annotated[np.int64, TypeKind("int")]
935+
annotated_float: Float32
936+
937+
instance = MixedStruct(
938+
numpy_int=np.int64(42),
939+
python_int=43,
940+
numpy_float=np.float64(2.718),
941+
python_float=3.14,
942+
string="hello, world",
943+
annotated_int=np.int64(42),
944+
annotated_float=2.0,
945+
)
946+
validate_full_roundtrip(instance, MixedStruct)

0 commit comments

Comments
 (0)