|
1 | 1 | import uuid |
2 | 2 | import datetime |
3 | 3 | from dataclasses import dataclass, make_dataclass |
4 | | -from typing import NamedTuple, Literal, Any, Callable |
| 4 | +from typing import NamedTuple, Literal, Any, Callable, Union |
5 | 5 | import pytest |
6 | 6 | import cocoindex |
7 | 7 | from cocoindex.typing import ( |
@@ -91,7 +91,7 @@ def validate_full_roundtrip( |
91 | 91 | decoded_value = build_engine_value_decoder(input_type or output_type, output_type)( |
92 | 92 | value_from_engine |
93 | 93 | ) |
94 | | - assert decoded_value == value |
| 94 | + np.testing.assert_array_equal(decoded_value, value) |
95 | 95 |
|
96 | 96 |
|
97 | 97 | def test_encode_engine_value_basic_types(): |
@@ -540,6 +540,11 @@ def test_vector_as_list() -> None: |
540 | 540 | Float64VectorType = Vector[np.float64, Literal[3]] |
541 | 541 | Int64VectorType = Vector[np.int64, Literal[3]] |
542 | 542 | Int32VectorType = Vector[np.int32, Literal[3]] |
| 543 | +UInt8VectorType = Vector[np.uint8, Literal[3]] |
| 544 | +UInt16VectorType = Vector[np.uint16, Literal[3]] |
| 545 | +UInt32VectorType = Vector[np.uint32, Literal[3]] |
| 546 | +UInt64VectorType = Vector[np.uint64, Literal[3]] |
| 547 | +StrVectorType = Vector[str] |
543 | 548 | NDArrayFloat32Type = NDArray[np.float32] |
544 | 549 | NDArrayFloat64Type = NDArray[np.float64] |
545 | 550 | NDArrayInt64Type = NDArray[np.int64] |
@@ -635,15 +640,6 @@ def test_uint_support(): |
635 | 640 | decoded = decoder(encoded) |
636 | 641 | assert np.array_equal(decoded, value_uint32) |
637 | 642 | assert decoded.dtype == np.uint32 |
638 | | - value_uint64 = np.array([1, 2, 3], dtype=np.uint64) |
639 | | - encoded = encode_engine_value(value_uint64) |
640 | | - assert np.array_equal(encoded, [1, 2, 3]) |
641 | | - decoder = make_engine_value_decoder( |
642 | | - [], {"kind": "Vector", "element_type": {"kind": "UInt8"}}, NDArray[np.uint64] |
643 | | - ) |
644 | | - decoded = decoder(encoded) |
645 | | - assert np.array_equal(decoded, value_uint64) |
646 | | - assert decoded.dtype == np.uint64 |
647 | 643 |
|
648 | 644 |
|
649 | 645 | def test_ndarray_dimension_mismatch(): |
@@ -765,3 +761,59 @@ def test_dump_vector_type_annotation_no_dim(): |
765 | 761 | } |
766 | 762 | } |
767 | 763 | assert dump_engine_object(Float64VectorTypeNoDim) == expected_dump_no_dim |
| 764 | + |
| 765 | + |
| 766 | +def test_full_roundtrip_vector_numeric_types() -> None: |
| 767 | + """Test full roundtrip for numeric vector types using NDArray.""" |
| 768 | + value_f32: Vector[np.float32, Literal[3]] = np.array( |
| 769 | + [1.0, 2.0, 3.0], dtype=np.float32 |
| 770 | + ) |
| 771 | + validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]]) |
| 772 | + value_f64: Vector[np.float64, Literal[3]] = np.array( |
| 773 | + [1.0, 2.0, 3.0], dtype=np.float64 |
| 774 | + ) |
| 775 | + validate_full_roundtrip(value_f64, Vector[np.float64, Literal[3]]) |
| 776 | + value_i32: Vector[np.int32, Literal[3]] = np.array([1, 2, 3], dtype=np.int32) |
| 777 | + validate_full_roundtrip(value_i32, Vector[np.int32, Literal[3]]) |
| 778 | + value_i64: Vector[np.int64, Literal[3]] = np.array([1, 2, 3], dtype=np.int64) |
| 779 | + validate_full_roundtrip(value_i64, Vector[np.int64, Literal[3]]) |
| 780 | + value_u8: Vector[np.uint8, Literal[3]] = np.array([1, 2, 3], dtype=np.uint8) |
| 781 | + validate_full_roundtrip(value_u8, Vector[np.uint8, Literal[3]]) |
| 782 | + value_u16: Vector[np.uint16, Literal[3]] = np.array([1, 2, 3], dtype=np.uint16) |
| 783 | + validate_full_roundtrip(value_u16, Vector[np.uint16, Literal[3]]) |
| 784 | + value_u32: Vector[np.uint32, Literal[3]] = np.array([1, 2, 3], dtype=np.uint32) |
| 785 | + validate_full_roundtrip(value_u32, Vector[np.uint32, Literal[3]]) |
| 786 | + value_u64: Vector[np.uint64, Literal[3]] = np.array([1, 2, 3], dtype=np.uint64) |
| 787 | + with pytest.raises(ValueError, match="type unsupported yet"): |
| 788 | + validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) |
| 789 | + |
| 790 | + |
| 791 | +def test_roundtrip_vector_no_dimension() -> None: |
| 792 | + """Test full roundtrip for vector types without dimension annotation.""" |
| 793 | + value_f64: Vector[np.float64] = np.array([1.0, 2.0, 3.0], dtype=np.float64) |
| 794 | + validate_full_roundtrip(value_f64, Vector[np.float64]) |
| 795 | + |
| 796 | + |
| 797 | +def test_roundtrip_string_vector() -> None: |
| 798 | + """Test full roundtrip for string vector using list.""" |
| 799 | + value_str: Vector[str] = ["hello", "world"] |
| 800 | + validate_full_roundtrip(value_str, Vector[str]) |
| 801 | + |
| 802 | + |
| 803 | +def test_roundtrip_empty_vector() -> None: |
| 804 | + """Test full roundtrip for empty numeric vector.""" |
| 805 | + value_empty: Vector[np.float32] = np.array([], dtype=np.float32) |
| 806 | + validate_full_roundtrip(value_empty, Vector[np.float32]) |
| 807 | + |
| 808 | + |
| 809 | +def test_roundtrip_dimension_mismatch() -> None: |
| 810 | + """Test that dimension mismatch raises an error during roundtrip.""" |
| 811 | + value_f32: Vector[np.float32, Literal[3]] = np.array([1.0, 2.0], dtype=np.float32) |
| 812 | + with pytest.raises(ValueError, match="Vector dimension mismatch"): |
| 813 | + validate_full_roundtrip(value_f32, Vector[np.float32, Literal[3]]) |
| 814 | + |
| 815 | + |
| 816 | +def test_roundtrip_list_backward_compatibility() -> None: |
| 817 | + """Test full roundtrip for list-based vectors for backward compatibility.""" |
| 818 | + value_list: list[int] = [1, 2, 3] |
| 819 | + validate_full_roundtrip(value_list, list[int]) |
0 commit comments