@@ -561,13 +561,13 @@ def test_vector_as_list() -> None:
561561def test_encode_engine_value_ndarray ():
562562 """Test encoding NDArray vectors to lists for the Rust engine."""
563563 vec_f32 : Float32VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
564- assert encode_engine_value (vec_f32 ) == [1.0 , 2.0 , 3.0 ]
564+ assert np . array_equal ( encode_engine_value (vec_f32 ), [1.0 , 2.0 , 3.0 ])
565565 vec_f64 : Float64VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
566- assert encode_engine_value (vec_f64 ) == [1.0 , 2.0 , 3.0 ]
566+ assert np . array_equal ( encode_engine_value (vec_f64 ), [1.0 , 2.0 , 3.0 ])
567567 vec_i64 : Int64VectorType = np .array ([1 , 2 , 3 ], dtype = np .int64 )
568- assert encode_engine_value (vec_i64 ) == [1 , 2 , 3 ]
568+ assert np . array_equal ( encode_engine_value (vec_i64 ), [1 , 2 , 3 ])
569569 vec_nd_f32 : NDArrayFloat32Type = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
570- assert encode_engine_value (vec_nd_f32 ) == [1.0 , 2.0 , 3.0 ]
570+ assert np . array_equal ( encode_engine_value (vec_nd_f32 ), [1.0 , 2.0 , 3.0 ])
571571
572572
573573def test_make_engine_value_decoder_ndarray ():
@@ -598,21 +598,21 @@ def test_roundtrip_ndarray_vector():
598598 """Test roundtrip encoding and decoding of NDArray vectors."""
599599 value_f32 : Float32VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
600600 encoded_f32 = encode_engine_value (value_f32 )
601- assert encoded_f32 == [1.0 , 2.0 , 3.0 ]
601+ np . array_equal ( encoded_f32 , [1.0 , 2.0 , 3.0 ])
602602 decoded_f32 = build_engine_value_decoder (Float32VectorType )(encoded_f32 )
603603 assert isinstance (decoded_f32 , np .ndarray )
604604 assert decoded_f32 .dtype == np .float32
605605 assert np .array_equal (decoded_f32 , value_f32 )
606606 value_i64 : Int64VectorType = np .array ([1 , 2 , 3 ], dtype = np .int64 )
607607 encoded_i64 = encode_engine_value (value_i64 )
608- assert encoded_i64 == [1 , 2 , 3 ]
608+ assert np . array_equal ( encoded_i64 , [1 , 2 , 3 ])
609609 decoded_i64 = build_engine_value_decoder (Int64VectorType )(encoded_i64 )
610610 assert isinstance (decoded_i64 , np .ndarray )
611611 assert decoded_i64 .dtype == np .int64
612612 assert np .array_equal (decoded_i64 , value_i64 )
613613 value_nd_f64 : NDArrayFloat64Type = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float64 )
614614 encoded_nd_f64 = encode_engine_value (value_nd_f64 )
615- assert encoded_nd_f64 == [1.0 , 2.0 , 3.0 ]
615+ assert np . array_equal ( encoded_nd_f64 , [1.0 , 2.0 , 3.0 ])
616616 decoded_nd_f64 = build_engine_value_decoder (NDArrayFloat64Type )(encoded_nd_f64 )
617617 assert isinstance (decoded_nd_f64 , np .ndarray )
618618 assert decoded_nd_f64 .dtype == np .float64
@@ -623,7 +623,7 @@ def test_uint_support():
623623 """Test encoding and decoding of unsigned integer vectors."""
624624 value_uint8 = np .array ([1 , 2 , 3 , 4 ], dtype = np .uint8 )
625625 encoded = encode_engine_value (value_uint8 )
626- assert encoded == [1 , 2 , 3 , 4 ]
626+ assert np . array_equal ( encoded , [1 , 2 , 3 , 4 ])
627627 decoder = make_engine_value_decoder (
628628 [], {"kind" : "Vector" , "element_type" : {"kind" : "UInt8" }}, NDArray [np .uint8 ]
629629 )
@@ -632,7 +632,7 @@ def test_uint_support():
632632 assert decoded .dtype == np .uint8
633633 value_uint16 = np .array ([1 , 2 , 3 , 4 ], dtype = np .uint16 )
634634 encoded = encode_engine_value (value_uint16 )
635- assert encoded == [1 , 2 , 3 , 4 ]
635+ assert np . array_equal ( encoded , [1 , 2 , 3 , 4 ])
636636 decoder = make_engine_value_decoder (
637637 [], {"kind" : "Vector" , "element_type" : {"kind" : "UInt16" }}, NDArray [np .uint16 ]
638638 )
@@ -641,7 +641,7 @@ def test_uint_support():
641641 assert decoded .dtype == np .uint16
642642 value_uint32 = np .array ([1 , 2 , 3 ], dtype = np .uint32 )
643643 encoded = encode_engine_value (value_uint32 )
644- assert encoded == [1 , 2 , 3 ]
644+ assert np . array_equal ( encoded , [1 , 2 , 3 ])
645645 decoder = make_engine_value_decoder (
646646 [], {"kind" : "Vector" , "element_type" : {"kind" : "UInt32" }}, NDArray [np .uint32 ]
647647 )
@@ -650,7 +650,7 @@ def test_uint_support():
650650 assert decoded .dtype == np .uint32
651651 value_uint64 = np .array ([1 , 2 , 3 ], dtype = np .uint64 )
652652 encoded = encode_engine_value (value_uint64 )
653- assert encoded == [1 , 2 , 3 ]
653+ assert np . array_equal ( encoded , [1 , 2 , 3 ])
654654 decoder = make_engine_value_decoder (
655655 [], {"kind" : "Vector" , "element_type" : {"kind" : "UInt8" }}, NDArray [np .uint64 ]
656656 )
@@ -663,7 +663,7 @@ def test_ndarray_dimension_mismatch():
663663 """Test dimension enforcement for Vector with specified dimension."""
664664 value : Float32VectorType = np .array ([1.0 , 2.0 ], dtype = np .float32 )
665665 encoded = encode_engine_value (value )
666- assert encoded == [1.0 , 2.0 ]
666+ assert np . array_equal ( encoded , [1.0 , 2.0 ])
667667 with pytest .raises (ValueError , match = "Vector dimension mismatch" ):
668668 build_engine_value_decoder (Float32VectorType )(encoded )
669669
@@ -679,9 +679,9 @@ def test_list_vector_backward_compatibility():
679679 assert np .array_equal (decoded , np .array ([1 , 2 , 3 , 4 , 5 ], dtype = np .int64 ))
680680 value_list : ListIntType = [1 , 2 , 3 , 4 , 5 ]
681681 encoded = encode_engine_value (value_list )
682- assert encoded == [1 , 2 , 3 , 4 , 5 ]
682+ assert np . array_equal ( encoded , [1 , 2 , 3 , 4 , 5 ])
683683 decoded = build_engine_value_decoder (ListIntType )(encoded )
684- assert decoded . tolist () == [1 , 2 , 3 , 4 , 5 ]
684+ assert np . array_equal ( decoded , [1 , 2 , 3 , 4 , 5 ])
685685
686686
687687def test_encode_complex_structure_with_ndarray ():
@@ -702,7 +702,9 @@ class MyStructWithNDArray:
702702 [1.0 , 0.5 ],
703703 100 ,
704704 ]
705- assert encoded == expected
705+ assert encoded [0 ] == expected [0 ]
706+ assert np .array_equal (encoded [1 ], expected [1 ])
707+ assert encoded [2 ] == expected [2 ]
706708
707709
708710def test_decode_nullable_ndarray_none_or_value_input ():
@@ -750,7 +752,7 @@ def test_decode_error_non_nullable_or_non_list_vector():
750752 decoder = make_engine_value_decoder ([], src_type_dict , NDArrayFloat32Type )
751753 with pytest .raises (ValueError , match = "Received null for non-nullable vector" ):
752754 decoder (None )
753- with pytest .raises (TypeError , match = "Expected a list for vector" ):
755+ with pytest .raises (TypeError , match = "Expected NDArray or list for vector" ):
754756 decoder ("not a list" )
755757
756758
0 commit comments