@@ -133,24 +133,24 @@ def validate_full_roundtrip(
133133
134134
135135def test_encode_engine_value_basic_types () -> None :
136- assert encode_engine_value (123 ) == 123
137- assert encode_engine_value (3.14 ) == 3.14
138- assert encode_engine_value ("hello" ) == "hello"
139- assert encode_engine_value (True ) is True
136+ assert encode_engine_value (123 , int ) == 123
137+ assert encode_engine_value (3.14 , float ) == 3.14
138+ assert encode_engine_value ("hello" , str ) == "hello"
139+ assert encode_engine_value (True , bool ) is True
140140
141141
142142def test_encode_engine_value_uuid () -> None :
143143 u = uuid .uuid4 ()
144- assert encode_engine_value (u ) == u
144+ assert encode_engine_value (u , uuid . UUID ) == u
145145
146146
147147def test_encode_engine_value_date_time_types () -> None :
148148 d = datetime .date (2024 , 1 , 1 )
149- assert encode_engine_value (d ) == d
149+ assert encode_engine_value (d , datetime . date ) == d
150150 t = datetime .time (12 , 30 )
151- assert encode_engine_value (t ) == t
151+ assert encode_engine_value (t , datetime . time ) == t
152152 dt = datetime .datetime (2024 , 1 , 1 , 12 , 30 )
153- assert encode_engine_value (dt ) == dt
153+ assert encode_engine_value (dt , datetime . datetime ) == dt
154154
155155
156156def test_encode_scalar_numpy_values () -> None :
@@ -161,17 +161,22 @@ def test_encode_scalar_numpy_values() -> None:
161161 (np .float64 (2.718 ), pytest .approx (2.718 )),
162162 ]
163163 for np_value , expected in test_cases :
164- encoded = encode_engine_value (np_value )
164+ encoded = encode_engine_value (np_value , type ( np_value ) )
165165 assert encoded == expected
166166 assert isinstance (encoded , (int , float ))
167167
168168
169169def test_encode_engine_value_struct () -> None :
170170 order = Order (order_id = "O123" , name = "mixed nuts" , price = 25.0 )
171- assert encode_engine_value (order ) == ["O123" , "mixed nuts" , 25.0 , "default_extra" ]
171+ assert encode_engine_value (order , Order ) == [
172+ "O123" ,
173+ "mixed nuts" ,
174+ 25.0 ,
175+ "default_extra" ,
176+ ]
172177
173178 order_nt = OrderNamedTuple (order_id = "O123" , name = "mixed nuts" , price = 25.0 )
174- assert encode_engine_value (order_nt ) == [
179+ assert encode_engine_value (order_nt , OrderNamedTuple ) == [
175180 "O123" ,
176181 "mixed nuts" ,
177182 25.0 ,
@@ -181,7 +186,7 @@ def test_encode_engine_value_struct() -> None:
181186
182187def test_encode_engine_value_list_of_structs () -> None :
183188 orders = [Order ("O1" , "item1" , 10.0 ), Order ("O2" , "item2" , 20.0 )]
184- assert encode_engine_value (orders ) == [
189+ assert encode_engine_value (orders , list [ Order ] ) == [
185190 ["O1" , "item1" , 10.0 , "default_extra" ],
186191 ["O2" , "item2" , 20.0 , "default_extra" ],
187192 ]
@@ -190,20 +195,20 @@ def test_encode_engine_value_list_of_structs() -> None:
190195 OrderNamedTuple ("O1" , "item1" , 10.0 ),
191196 OrderNamedTuple ("O2" , "item2" , 20.0 ),
192197 ]
193- assert encode_engine_value (orders_nt ) == [
198+ assert encode_engine_value (orders_nt , list [ OrderNamedTuple ] ) == [
194199 ["O1" , "item1" , 10.0 , "default_extra" ],
195200 ["O2" , "item2" , 20.0 , "default_extra" ],
196201 ]
197202
198203
199204def test_encode_engine_value_struct_with_list () -> None :
200205 basket = Basket (items = ["apple" , "banana" ])
201- assert encode_engine_value (basket ) == [["apple" , "banana" ]]
206+ assert encode_engine_value (basket , Basket ) == [["apple" , "banana" ]]
202207
203208
204209def test_encode_engine_value_nested_struct () -> None :
205210 customer = Customer (name = "Alice" , order = Order ("O1" , "item1" , 10.0 ))
206- assert encode_engine_value (customer ) == [
211+ assert encode_engine_value (customer , Customer ) == [
207212 "Alice" ,
208213 ["O1" , "item1" , 10.0 , "default_extra" ],
209214 None ,
@@ -212,28 +217,28 @@ def test_encode_engine_value_nested_struct() -> None:
212217 customer_nt = CustomerNamedTuple (
213218 name = "Alice" , order = OrderNamedTuple ("O1" , "item1" , 10.0 )
214219 )
215- assert encode_engine_value (customer_nt ) == [
220+ assert encode_engine_value (customer_nt , CustomerNamedTuple ) == [
216221 "Alice" ,
217222 ["O1" , "item1" , 10.0 , "default_extra" ],
218223 None ,
219224 ]
220225
221226
222227def test_encode_engine_value_empty_list () -> None :
223- assert encode_engine_value ([]) == []
224- assert encode_engine_value ([[]]) == [[]]
228+ assert encode_engine_value ([], list ) == []
229+ assert encode_engine_value ([[]], list [ list ] ) == [[]]
225230
226231
227232def test_encode_engine_value_tuple () -> None :
228- assert encode_engine_value (()) == []
229- assert encode_engine_value ((1 , 2 , 3 )) == [1 , 2 , 3 ]
230- assert encode_engine_value (((1 , 2 ), (3 , 4 ))) == [[1 , 2 ], [3 , 4 ]]
231- assert encode_engine_value (([],)) == [[]]
232- assert encode_engine_value (((),)) == [[]]
233+ assert encode_engine_value ((), Any ) == []
234+ assert encode_engine_value ((1 , 2 , 3 ), Any ) == [1 , 2 , 3 ]
235+ assert encode_engine_value (((1 , 2 ), (3 , 4 )), Any ) == [[1 , 2 ], [3 , 4 ]]
236+ assert encode_engine_value (([],), Any ) == [[]]
237+ assert encode_engine_value (((),), Any ) == [[]]
233238
234239
235240def test_encode_engine_value_none () -> None :
236- assert encode_engine_value (None ) is None
241+ assert encode_engine_value (None , int | None ) is None
237242
238243
239244def test_roundtrip_basic_types () -> None :
@@ -743,7 +748,7 @@ class OrderKey:
743748
744749def test_vector_as_vector () -> None :
745750 value = np .array ([1 , 2 , 3 , 4 , 5 ], dtype = np .int64 )
746- encoded = encode_engine_value (value )
751+ encoded = encode_engine_value (value , IntVectorType )
747752 assert np .array_equal (encoded , value )
748753 decoded = build_engine_value_decoder (IntVectorType )(encoded )
749754 assert np .array_equal (decoded , value )
@@ -754,7 +759,7 @@ def test_vector_as_vector() -> None:
754759
755760def test_vector_as_list () -> None :
756761 value : ListIntType = [1 , 2 , 3 , 4 , 5 ]
757- encoded = encode_engine_value (value )
762+ encoded = encode_engine_value (value , ListIntType )
758763 assert encoded == [1 , 2 , 3 , 4 , 5 ]
759764 decoded = build_engine_value_decoder (ListIntType )(encoded )
760765 assert np .array_equal (decoded , value )
@@ -772,13 +777,19 @@ def test_vector_as_list() -> None:
772777def test_encode_engine_value_ndarray () -> None :
773778 """Test encoding NDArray vectors to lists for the Rust engine."""
774779 vec_f32 : Float32VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
775- assert np .array_equal (encode_engine_value (vec_f32 ), [1.0 , 2.0 , 3.0 ])
780+ assert np .array_equal (
781+ encode_engine_value (vec_f32 , Float32VectorType ), [1.0 , 2.0 , 3.0 ]
782+ )
776783 vec_f64 : Float64VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
777- assert np .array_equal (encode_engine_value (vec_f64 ), [1.0 , 2.0 , 3.0 ])
784+ assert np .array_equal (
785+ encode_engine_value (vec_f64 , Float64VectorType ), [1.0 , 2.0 , 3.0 ]
786+ )
778787 vec_i64 : Int64VectorType = np .array ([1 , 2 , 3 ], dtype = np .int64 )
779- assert np .array_equal (encode_engine_value (vec_i64 ), [1 , 2 , 3 ])
788+ assert np .array_equal (encode_engine_value (vec_i64 , Int64VectorType ), [1 , 2 , 3 ])
780789 vec_nd_f32 : NDArrayFloat32Type = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
781- assert np .array_equal (encode_engine_value (vec_nd_f32 ), [1.0 , 2.0 , 3.0 ])
790+ assert np .array_equal (
791+ encode_engine_value (vec_nd_f32 , NDArrayFloat32Type ), [1.0 , 2.0 , 3.0 ]
792+ )
782793
783794
784795def test_make_engine_value_decoder_ndarray () -> None :
@@ -808,21 +819,21 @@ def test_make_engine_value_decoder_ndarray() -> None:
808819def test_roundtrip_ndarray_vector () -> None :
809820 """Test roundtrip encoding and decoding of NDArray vectors."""
810821 value_f32 = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
811- encoded_f32 = encode_engine_value (value_f32 )
822+ encoded_f32 = encode_engine_value (value_f32 , Float32VectorType )
812823 np .array_equal (encoded_f32 , [1.0 , 2.0 , 3.0 ])
813824 decoded_f32 = build_engine_value_decoder (Float32VectorType )(encoded_f32 )
814825 assert isinstance (decoded_f32 , np .ndarray )
815826 assert decoded_f32 .dtype == np .float32
816827 assert np .array_equal (decoded_f32 , value_f32 )
817828 value_i64 = np .array ([1 , 2 , 3 ], dtype = np .int64 )
818- encoded_i64 = encode_engine_value (value_i64 )
829+ encoded_i64 = encode_engine_value (value_i64 , Int64VectorType )
819830 assert np .array_equal (encoded_i64 , [1 , 2 , 3 ])
820831 decoded_i64 = build_engine_value_decoder (Int64VectorType )(encoded_i64 )
821832 assert isinstance (decoded_i64 , np .ndarray )
822833 assert decoded_i64 .dtype == np .int64
823834 assert np .array_equal (decoded_i64 , value_i64 )
824835 value_nd_f64 : NDArrayFloat64Type = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
825- encoded_nd_f64 = encode_engine_value (value_nd_f64 )
836+ encoded_nd_f64 = encode_engine_value (value_nd_f64 , NDArrayFloat64Type )
826837 assert np .array_equal (encoded_nd_f64 , [1.0 , 2.0 , 3.0 ])
827838 decoded_nd_f64 = build_engine_value_decoder (NDArrayFloat64Type )(encoded_nd_f64 )
828839 assert isinstance (decoded_nd_f64 , np .ndarray )
@@ -833,7 +844,7 @@ def test_roundtrip_ndarray_vector() -> None:
833844def test_ndarray_dimension_mismatch () -> None :
834845 """Test dimension enforcement for Vector with specified dimension."""
835846 value = np .array ([1.0 , 2.0 ], dtype = np .float32 )
836- encoded = encode_engine_value (value )
847+ encoded = encode_engine_value (value , NDArray [ np . float32 ] )
837848 assert np .array_equal (encoded , [1.0 , 2.0 ])
838849 with pytest .raises (ValueError , match = "Vector dimension mismatch" ):
839850 build_engine_value_decoder (Float32VectorType )(encoded )
@@ -842,14 +853,14 @@ def test_ndarray_dimension_mismatch() -> None:
842853def test_list_vector_backward_compatibility () -> None :
843854 """Test that list-based vectors still work for backward compatibility."""
844855 value = [1 , 2 , 3 , 4 , 5 ]
845- encoded = encode_engine_value (value )
856+ encoded = encode_engine_value (value , list [ int ] )
846857 assert encoded == [1 , 2 , 3 , 4 , 5 ]
847858 decoded = build_engine_value_decoder (IntVectorType )(encoded )
848859 assert isinstance (decoded , np .ndarray )
849860 assert decoded .dtype == np .int64
850861 assert np .array_equal (decoded , np .array ([1 , 2 , 3 , 4 , 5 ], dtype = np .int64 ))
851862 value_list : ListIntType = [1 , 2 , 3 , 4 , 5 ]
852- encoded = encode_engine_value (value_list )
863+ encoded = encode_engine_value (value_list , ListIntType )
853864 assert np .array_equal (encoded , [1 , 2 , 3 , 4 , 5 ])
854865 decoded = build_engine_value_decoder (ListIntType )(encoded )
855866 assert np .array_equal (decoded , [1 , 2 , 3 , 4 , 5 ])
@@ -867,7 +878,7 @@ class MyStructWithNDArray:
867878 original = MyStructWithNDArray (
868879 name = "test_np" , data = np .array ([1.0 , 0.5 ], dtype = np .float32 ), value = 100
869880 )
870- encoded = encode_engine_value (original )
881+ encoded = encode_engine_value (original , MyStructWithNDArray )
871882
872883 assert encoded [0 ] == original .name
873884 assert np .array_equal (encoded [1 ], original .data )
@@ -1026,7 +1037,7 @@ def test_full_roundtrip_vector_of_vector() -> None:
10261037 ),
10271038 (
10281039 value_f32 ,
1029- np .typing .NDArray [np .typing . NDArray [ np . float32 ] ],
1040+ np .typing .NDArray [np .float32 ],
10301041 ),
10311042 )
10321043
0 commit comments