@@ -91,23 +91,26 @@ def validate_full_roundtrip(
9191 """
9292 from cocoindex import _engine # type: ignore
9393
94+ def eq (a : Any , b : Any ) -> bool :
95+ if isinstance (a , np .ndarray ) and isinstance (b , np .ndarray ):
96+ return np .array_equal (a , b )
97+ return type (a ) == type (b ) and not not (a == b )
98+
9499 encoded_value = encode_engine_value (value )
95100 value_type = value_type or type (value )
96101 encoded_output_type = encode_enriched_type (value_type )["type" ]
97102 value_from_engine = _engine .testutil .seder_roundtrip (
98103 encoded_value , encoded_output_type
99104 )
100- decoded_value = build_engine_value_decoder (value_type , value_type )(
101- value_from_engine
102- )
103- np .testing .assert_array_equal (decoded_value , value )
105+ decoder = make_engine_value_decoder ([], encoded_output_type , value_type )
106+ decoded_value = decoder (value_from_engine )
107+ assert eq (decoded_value , value )
104108
105109 if other_decoded_values is not None :
106110 for other_value , other_type in other_decoded_values :
107- other_decoded_value = build_engine_value_decoder (other_type , other_type )(
108- value_from_engine
109- )
110- np .testing .assert_array_equal (other_decoded_value , other_value )
111+ decoder = make_engine_value_decoder ([], encoded_output_type , other_type )
112+ other_decoded_value = decoder (value_from_engine )
113+ assert eq (other_decoded_value , other_value )
111114
112115
113116def test_encode_engine_value_basic_types () -> None :
@@ -215,19 +218,38 @@ def test_encode_engine_value_none() -> None:
215218
216219
217220def test_roundtrip_basic_types () -> None :
218- validate_full_roundtrip (42 , int )
221+ validate_full_roundtrip (42 , int , ( 42 , None ) )
219222 validate_full_roundtrip (3.25 , float , (3.25 , Float64 ))
220- validate_full_roundtrip (3.25 , Float64 , (3.25 , float ))
221- validate_full_roundtrip (3.25 , Float32 )
222- validate_full_roundtrip ("hello" , str )
223- validate_full_roundtrip (True , bool )
224- validate_full_roundtrip (False , bool )
225- validate_full_roundtrip (datetime .date (2025 , 1 , 1 ), datetime .date )
226- validate_full_roundtrip (datetime .datetime .now (), cocoindex .LocalDateTime )
227223 validate_full_roundtrip (
228- datetime .datetime .now (datetime .UTC ), cocoindex .OffsetDateTime
224+ 3.25 , Float64 , (3.25 , float ), (np .float64 (3.25 ), np .float64 )
225+ )
226+ validate_full_roundtrip (
227+ 3.25 , Float32 , (3.25 , float ), (np .float32 (3.25 ), np .float32 )
228+ )
229+ validate_full_roundtrip ("hello" , str , ("hello" , None ))
230+ validate_full_roundtrip (True , bool , (True , None ))
231+ validate_full_roundtrip (False , bool , (False , None ))
232+ validate_full_roundtrip (
233+ datetime .date (2025 , 1 , 1 ), datetime .date , (datetime .date (2025 , 1 , 1 ), None )
229234 )
230235
236+ validate_full_roundtrip (
237+ datetime .datetime (2025 , 1 , 2 , 3 , 4 , 5 , 123456 ),
238+ cocoindex .LocalDateTime ,
239+ (datetime .datetime (2025 , 1 , 2 , 3 , 4 , 5 , 123456 ), datetime .datetime ),
240+ )
241+ validate_full_roundtrip (
242+ datetime .datetime (2025 , 1 , 2 , 3 , 4 , 5 , 123456 , datetime .UTC ),
243+ cocoindex .OffsetDateTime ,
244+ (
245+ datetime .datetime (2025 , 1 , 2 , 3 , 4 , 5 , 123456 , datetime .UTC ),
246+ datetime .datetime ,
247+ ),
248+ )
249+
250+ uuid_value = uuid .uuid4 ()
251+ validate_full_roundtrip (uuid_value , uuid .UUID , (uuid_value , None ))
252+
231253
232254def test_decode_scalar_numpy_values () -> None :
233255 test_cases = [
@@ -849,37 +871,72 @@ def test_dump_vector_type_annotation_no_dim() -> None:
849871
850872def test_full_roundtrip_vector_numeric_types () -> None :
851873 """Test full roundtrip for numeric vector types using NDArray."""
852- value_f32 : Vector [np .float32 , Literal [3 ]] = np .array (
853- [1.0 , 2.0 , 3.0 ], dtype = np .float32
874+ value_f32 = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
875+ validate_full_roundtrip (
876+ value_f32 ,
877+ Vector [np .float32 , Literal [3 ]],
878+ ([np .float32 (1.0 ), np .float32 (2.0 ), np .float32 (3.0 )], list [np .float32 ]),
879+ ([1.0 , 2.0 , 3.0 ], list [cocoindex .Float32 ]),
880+ ([1.0 , 2.0 , 3.0 ], list [float ]),
881+ )
882+ validate_full_roundtrip (
883+ value_f32 ,
884+ np .typing .NDArray [np .float32 ],
885+ ([np .float32 (1.0 ), np .float32 (2.0 ), np .float32 (3.0 )], list [np .float32 ]),
886+ ([1.0 , 2.0 , 3.0 ], list [cocoindex .Float32 ]),
887+ ([1.0 , 2.0 , 3.0 ], list [float ]),
888+ )
889+ validate_full_roundtrip (
890+ value_f32 .tolist (),
891+ list [np .float32 ],
892+ (value_f32 , Vector [np .float32 , Literal [3 ]]),
893+ ([1.0 , 2.0 , 3.0 ], list [cocoindex .Float32 ]),
894+ ([1.0 , 2.0 , 3.0 ], list [float ]),
895+ )
896+
897+ value_f64 = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
898+ validate_full_roundtrip (
899+ value_f64 ,
900+ Vector [np .float64 , Literal [3 ]],
901+ ([np .float64 (1.0 ), np .float64 (2.0 ), np .float64 (3.0 )], list [np .float64 ]),
902+ ([1.0 , 2.0 , 3.0 ], list [cocoindex .Float64 ]),
903+ ([1.0 , 2.0 , 3.0 ], list [float ]),
854904 )
855- validate_full_roundtrip (value_f32 , Vector [np .float32 , Literal [3 ]])
856- value_f64 : Vector [np .float64 , Literal [3 ]] = np .array (
857- [1.0 , 2.0 , 3.0 ], dtype = np .float64
905+
906+ value_i64 = np .array ([1 , 2 , 3 ], dtype = np .int64 )
907+ validate_full_roundtrip (
908+ value_i64 ,
909+ Vector [np .int64 , Literal [3 ]],
910+ ([np .int64 (1 ), np .int64 (2 ), np .int64 (3 )], list [np .int64 ]),
911+ ([1 , 2 , 3 ], list [int ]),
858912 )
859- validate_full_roundtrip (value_f64 , Vector [np .float64 , Literal [3 ]])
860- value_i64 : Vector [np .int64 , Literal [3 ]] = np .array ([1 , 2 , 3 ], dtype = np .int64 )
861- validate_full_roundtrip (value_i64 , Vector [np .int64 , Literal [3 ]])
862- value_i32 : Vector [np .int32 , Literal [3 ]] = np .array ([1 , 2 , 3 ], dtype = np .int32 )
913+
914+ value_i32 = np .array ([1 , 2 , 3 ], dtype = np .int32 )
863915 with pytest .raises (ValueError , match = "Unsupported NumPy dtype" ):
864916 validate_full_roundtrip (value_i32 , Vector [np .int32 , Literal [3 ]])
865- value_u8 : Vector [ np . uint8 , Literal [ 3 ]] = np .array ([1 , 2 , 3 ], dtype = np .uint8 )
917+ value_u8 = np .array ([1 , 2 , 3 ], dtype = np .uint8 )
866918 with pytest .raises (ValueError , match = "Unsupported NumPy dtype" ):
867919 validate_full_roundtrip (value_u8 , Vector [np .uint8 , Literal [3 ]])
868- value_u16 : Vector [ np . uint16 , Literal [ 3 ]] = np .array ([1 , 2 , 3 ], dtype = np .uint16 )
920+ value_u16 = np .array ([1 , 2 , 3 ], dtype = np .uint16 )
869921 with pytest .raises (ValueError , match = "Unsupported NumPy dtype" ):
870922 validate_full_roundtrip (value_u16 , Vector [np .uint16 , Literal [3 ]])
871- value_u32 : Vector [ np . uint32 , Literal [ 3 ]] = np .array ([1 , 2 , 3 ], dtype = np .uint32 )
923+ value_u32 = np .array ([1 , 2 , 3 ], dtype = np .uint32 )
872924 with pytest .raises (ValueError , match = "Unsupported NumPy dtype" ):
873925 validate_full_roundtrip (value_u32 , Vector [np .uint32 , Literal [3 ]])
874- value_u64 : Vector [ np . uint64 , Literal [ 3 ]] = np .array ([1 , 2 , 3 ], dtype = np .uint64 )
926+ value_u64 = np .array ([1 , 2 , 3 ], dtype = np .uint64 )
875927 with pytest .raises (ValueError , match = "Unsupported NumPy dtype" ):
876928 validate_full_roundtrip (value_u64 , Vector [np .uint64 , Literal [3 ]])
877929
878930
879931def test_roundtrip_vector_no_dimension () -> None :
880932 """Test full roundtrip for vector types without dimension annotation."""
881- value_f64 : Vector [np .float64 ] = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
882- validate_full_roundtrip (value_f64 , Vector [np .float64 ])
933+ value_f64 = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
934+ validate_full_roundtrip (
935+ value_f64 ,
936+ Vector [np .float64 ],
937+ ([1.0 , 2.0 , 3.0 ], list [float ]),
938+ (np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 ), np .typing .NDArray [np .float64 ]),
939+ )
883940
884941
885942def test_roundtrip_string_vector () -> None :
@@ -904,9 +961,9 @@ def test_roundtrip_dimension_mismatch() -> None:
904961def test_full_roundtrip_scalar_numeric_types () -> None :
905962 """Test full roundtrip for scalar NumPy numeric types."""
906963 # Test supported scalar types
907- validate_full_roundtrip (np .int64 (42 ), np .int64 )
908- validate_full_roundtrip (np .float32 (3.14 ), np .float32 )
909- validate_full_roundtrip (np .float64 (2.718 ), np .float64 )
964+ validate_full_roundtrip (np .int64 (42 ), np .int64 , ( 42 , int ) )
965+ validate_full_roundtrip (np .float32 (3.25 ), np .float32 , ( 3.25 , cocoindex . Float32 ) )
966+ validate_full_roundtrip (np .float64 (3.25 ), np .float64 , ( 3.25 , cocoindex . Float64 ) )
910967
911968 # Test unsupported scalar types
912969 for unsupported_type in [np .int32 , np .uint8 , np .uint16 , np .uint32 , np .uint64 ]:
0 commit comments