|
1 | | -import uuid |
2 | 1 | import datetime |
| 2 | +import uuid |
3 | 3 | from dataclasses import dataclass, make_dataclass |
4 | | -from typing import NamedTuple, Literal, Any, Callable, Union |
| 4 | +from typing import Annotated, Any, Callable, Literal, NamedTuple |
| 5 | + |
| 6 | +import numpy as np |
5 | 7 | import pytest |
| 8 | +from numpy.typing import NDArray |
| 9 | + |
6 | 10 | import cocoindex |
7 | | -from cocoindex.typing import ( |
8 | | - encode_enriched_type, |
9 | | - Vector, |
10 | | - Float32, |
11 | | - Float64, |
12 | | -) |
13 | 11 | from cocoindex.convert import ( |
| 12 | + dump_engine_object, |
14 | 13 | encode_engine_value, |
15 | 14 | make_engine_value_decoder, |
16 | | - dump_engine_object, |
17 | 15 | ) |
18 | | -import numpy as np |
19 | | -from numpy.typing import NDArray |
| 16 | +from cocoindex.typing import ( |
| 17 | + Float32, |
| 18 | + Float64, |
| 19 | + TypeKind, |
| 20 | + Vector, |
| 21 | + encode_enriched_type, |
| 22 | +) |
20 | 23 |
|
21 | 24 |
|
22 | 25 | @dataclass |
@@ -128,6 +131,19 @@ def test_encode_engine_value_date_time_types() -> None: |
128 | 131 | assert encode_engine_value(dt) == dt |
129 | 132 |
|
130 | 133 |
|
| 134 | +def test_encode_scalar_numpy_values() -> None: |
| 135 | + """Test encoding scalar NumPy values to engine-compatible values.""" |
| 136 | + test_cases = [ |
| 137 | + (np.int64(42), 42), |
| 138 | + (np.float32(3.14), pytest.approx(3.14)), |
| 139 | + (np.float64(2.718), pytest.approx(2.718)), |
| 140 | + ] |
| 141 | + for np_value, expected in test_cases: |
| 142 | + encoded = encode_engine_value(np_value) |
| 143 | + assert encoded == expected |
| 144 | + assert isinstance(encoded, (int, float)) |
| 145 | + |
| 146 | + |
131 | 147 | def test_encode_engine_value_struct() -> None: |
132 | 148 | order = Order(order_id="O123", name="mixed nuts", price=25.0) |
133 | 149 | assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"] |
@@ -213,6 +229,47 @@ def test_roundtrip_basic_types() -> None: |
213 | 229 | ) |
214 | 230 |
|
215 | 231 |
|
| 232 | +def test_decode_scalar_numpy_values() -> None: |
| 233 | + test_cases = [ |
| 234 | + ({"kind": "Int64"}, np.int64, 42, np.int64(42)), |
| 235 | + ({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)), |
| 236 | + ({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)), |
| 237 | + ] |
| 238 | + for src_type, dst_type, input_value, expected in test_cases: |
| 239 | + decoder = make_engine_value_decoder(["field"], src_type, dst_type) |
| 240 | + result = decoder(input_value) |
| 241 | + assert isinstance(result, dst_type) |
| 242 | + assert result == expected |
| 243 | + |
| 244 | + |
| 245 | +def test_non_ndarray_vector_decoding() -> None: |
| 246 | + # Test list[np.float64] |
| 247 | + src_type = { |
| 248 | + "kind": "Vector", |
| 249 | + "element_type": {"kind": "Float64"}, |
| 250 | + "dimension": None, |
| 251 | + } |
| 252 | + dst_type_float = list[np.float64] |
| 253 | + decoder = make_engine_value_decoder(["field"], src_type, dst_type_float) |
| 254 | + input_numbers = [1.0, 2.0, 3.0] |
| 255 | + result = decoder(input_numbers) |
| 256 | + assert isinstance(result, list) |
| 257 | + assert all(isinstance(x, np.float64) for x in result) |
| 258 | + assert result == [np.float64(1.0), np.float64(2.0), np.float64(3.0)] |
| 259 | + |
| 260 | + # Test list[Uuid] |
| 261 | + src_type = {"kind": "Vector", "element_type": {"kind": "Uuid"}, "dimension": None} |
| 262 | + dst_type_uuid = list[uuid.UUID] |
| 263 | + decoder = make_engine_value_decoder(["field"], src_type, dst_type_uuid) |
| 264 | + uuid1 = uuid.uuid4() |
| 265 | + uuid2 = uuid.uuid4() |
| 266 | + input_bytes = [uuid1.bytes, uuid2.bytes] |
| 267 | + result = decoder(input_bytes) |
| 268 | + assert isinstance(result, list) |
| 269 | + assert all(isinstance(x, uuid.UUID) for x in result) |
| 270 | + assert result == [uuid1, uuid2] |
| 271 | + |
| 272 | + |
216 | 273 | @pytest.mark.parametrize( |
217 | 274 | "data_type, engine_val, expected", |
218 | 275 | [ |
@@ -565,12 +622,6 @@ def test_vector_as_list() -> None: |
565 | 622 | Float32VectorType = Vector[np.float32, Literal[3]] |
566 | 623 | Float64VectorType = Vector[np.float64, Literal[3]] |
567 | 624 | Int64VectorType = Vector[np.int64, Literal[3]] |
568 | | -Int32VectorType = Vector[np.int32, Literal[3]] |
569 | | -UInt8VectorType = Vector[np.uint8, Literal[3]] |
570 | | -UInt16VectorType = Vector[np.uint16, Literal[3]] |
571 | | -UInt32VectorType = Vector[np.uint32, Literal[3]] |
572 | | -UInt64VectorType = Vector[np.uint64, Literal[3]] |
573 | | -StrVectorType = Vector[str] |
574 | 625 | NDArrayFloat32Type = NDArray[np.float32] |
575 | 626 | NDArrayFloat64Type = NDArray[np.float64] |
576 | 627 | NDArrayInt64Type = NDArray[np.int64] |
@@ -767,19 +818,19 @@ def test_full_roundtrip_vector_numeric_types() -> None: |
767 | 818 | value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64) |
768 | 819 | validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]]) |
769 | 820 | value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32) |
770 | | - with pytest.raises(ValueError, match="type unsupported yet"): |
| 821 | + with pytest.raises(ValueError, match="Unsupported NumPy dtype"): |
771 | 822 | validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]]) |
772 | 823 | value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8) |
773 | | - with pytest.raises(ValueError, match="type unsupported yet"): |
| 824 | + with pytest.raises(ValueError, match="Unsupported NumPy dtype"): |
774 | 825 | validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]]) |
775 | 826 | value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16) |
776 | | - with pytest.raises(ValueError, match="type unsupported yet"): |
| 827 | + with pytest.raises(ValueError, match="Unsupported NumPy dtype"): |
777 | 828 | validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]]) |
778 | 829 | value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32) |
779 | | - with pytest.raises(ValueError, match="type unsupported yet"): |
| 830 | + with pytest.raises(ValueError, match="Unsupported NumPy dtype"): |
780 | 831 | validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]]) |
781 | 832 | value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64) |
782 | | - with pytest.raises(ValueError, match="type unsupported yet"): |
| 833 | + with pytest.raises(ValueError, match="Unsupported NumPy dtype"): |
783 | 834 | validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) |
784 | 835 |
|
785 | 836 |
|
@@ -808,7 +859,88 @@ def test_roundtrip_dimension_mismatch() -> None: |
808 | 859 | validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]]) |
809 | 860 |
|
810 | 861 |
|
811 | | -def test_roundtrip_list_backward_compatibility() -> None: |
812 | | - """Test full roundtrip for list-based vectors for backward compatibility.""" |
813 | | - value_list: list[int] = [1, 2, 3] |
814 | | - validate_full_roundtrip(value_list, list[int]) |
| 862 | +def test_full_roundtrip_scalar_numeric_types() -> None: |
| 863 | + """Test full roundtrip for scalar NumPy numeric types.""" |
| 864 | + # Test supported scalar types |
| 865 | + validate_full_roundtrip(np.int64(42), np.int64) |
| 866 | + validate_full_roundtrip(np.float32(3.14), np.float32) |
| 867 | + validate_full_roundtrip(np.float64(2.718), np.float64) |
| 868 | + |
| 869 | + # Test unsupported scalar types |
| 870 | + for unsupported_type in [np.int32, np.uint8, np.uint16, np.uint32, np.uint64]: |
| 871 | + with pytest.raises(ValueError, match="Unsupported NumPy dtype"): |
| 872 | + validate_full_roundtrip(unsupported_type(1), unsupported_type) |
| 873 | + |
| 874 | + |
| 875 | +def test_full_roundtrip_nullable_scalar() -> None: |
| 876 | + """Test full roundtrip for nullable scalar NumPy types.""" |
| 877 | + # Test with non-null values |
| 878 | + validate_full_roundtrip(np.int64(42), np.int64 | None) |
| 879 | + validate_full_roundtrip(np.float32(3.14), np.float32 | None) |
| 880 | + validate_full_roundtrip(np.float64(2.718), np.float64 | None) |
| 881 | + |
| 882 | + # Test with None |
| 883 | + validate_full_roundtrip(None, np.int64 | None) |
| 884 | + validate_full_roundtrip(None, np.float32 | None) |
| 885 | + validate_full_roundtrip(None, np.float64 | None) |
| 886 | + |
| 887 | + |
| 888 | +def test_full_roundtrip_scalar_in_struct() -> None: |
| 889 | + """Test full roundtrip for scalar NumPy types in a dataclass.""" |
| 890 | + |
| 891 | + @dataclass |
| 892 | + class NumericStruct: |
| 893 | + int_field: np.int64 |
| 894 | + float32_field: np.float32 |
| 895 | + float64_field: np.float64 |
| 896 | + |
| 897 | + instance = NumericStruct( |
| 898 | + int_field=np.int64(42), |
| 899 | + float32_field=np.float32(3.14), |
| 900 | + float64_field=np.float64(2.718), |
| 901 | + ) |
| 902 | + validate_full_roundtrip(instance, NumericStruct) |
| 903 | + |
| 904 | + |
| 905 | +def test_full_roundtrip_scalar_in_nested_struct() -> None: |
| 906 | + """Test full roundtrip for scalar NumPy types in a nested struct.""" |
| 907 | + |
| 908 | + @dataclass |
| 909 | + class InnerStruct: |
| 910 | + value: np.float64 |
| 911 | + |
| 912 | + @dataclass |
| 913 | + class OuterStruct: |
| 914 | + inner: InnerStruct |
| 915 | + count: np.int64 |
| 916 | + |
| 917 | + instance = OuterStruct( |
| 918 | + inner=InnerStruct(value=np.float64(2.718)), |
| 919 | + count=np.int64(1), |
| 920 | + ) |
| 921 | + validate_full_roundtrip(instance, OuterStruct) |
| 922 | + |
| 923 | + |
| 924 | +def test_full_roundtrip_scalar_with_python_types() -> None: |
| 925 | + """Test full roundtrip for structs mixing NumPy and Python scalar types.""" |
| 926 | + |
| 927 | + @dataclass |
| 928 | + class MixedStruct: |
| 929 | + numpy_int: np.int64 |
| 930 | + python_int: int |
| 931 | + numpy_float: np.float64 |
| 932 | + python_float: float |
| 933 | + string: str |
| 934 | + annotated_int: Annotated[np.int64, TypeKind("int")] |
| 935 | + annotated_float: Float32 |
| 936 | + |
| 937 | + instance = MixedStruct( |
| 938 | + numpy_int=np.int64(42), |
| 939 | + python_int=43, |
| 940 | + numpy_float=np.float64(2.718), |
| 941 | + python_float=3.14, |
| 942 | + string="hello, world", |
| 943 | + annotated_int=np.int64(42), |
| 944 | + annotated_float=2.0, |
| 945 | + ) |
| 946 | + validate_full_roundtrip(instance, MixedStruct) |
0 commit comments