@@ -34,7 +34,7 @@ class Tag:
3434
3535@dataclass
3636class Basket :
37- items : list
37+ items : list [ str ]
3838
3939
4040@dataclass
@@ -86,7 +86,7 @@ def validate_full_roundtrip(
8686 `other_decoded_values` is a tuple of (value, type) pairs.
8787 If provided, also validate the value can be decoded to the other types.
8888 """
89- from cocoindex import _engine
89+ from cocoindex import _engine # type: ignore
9090
9191 encoded_value = encode_engine_value (value )
9292 value_type = value_type or type (value )
@@ -107,19 +107,19 @@ def validate_full_roundtrip(
107107 np .testing .assert_array_equal (other_decoded_value , other_value )
108108
109109
110- def test_encode_engine_value_basic_types ():
110+ def test_encode_engine_value_basic_types () -> None :
111111 assert encode_engine_value (123 ) == 123
112112 assert encode_engine_value (3.14 ) == 3.14
113113 assert encode_engine_value ("hello" ) == "hello"
114114 assert encode_engine_value (True ) is True
115115
116116
117- def test_encode_engine_value_uuid ():
117+ def test_encode_engine_value_uuid () -> None :
118118 u = uuid .uuid4 ()
119119 assert encode_engine_value (u ) == u .bytes
120120
121121
122- def test_encode_engine_value_date_time_types ():
122+ def test_encode_engine_value_date_time_types () -> None :
123123 d = datetime .date (2024 , 1 , 1 )
124124 assert encode_engine_value (d ) == d
125125 t = datetime .time (12 , 30 )
@@ -128,7 +128,7 @@ def test_encode_engine_value_date_time_types():
128128 assert encode_engine_value (dt ) == dt
129129
130130
131- def test_encode_engine_value_struct ():
131+ def test_encode_engine_value_struct () -> None :
132132 order = Order (order_id = "O123" , name = "mixed nuts" , price = 25.0 )
133133 assert encode_engine_value (order ) == ["O123" , "mixed nuts" , 25.0 , "default_extra" ]
134134
@@ -141,7 +141,7 @@ def test_encode_engine_value_struct():
141141 ]
142142
143143
144- def test_encode_engine_value_list_of_structs ():
144+ def test_encode_engine_value_list_of_structs () -> None :
145145 orders = [Order ("O1" , "item1" , 10.0 ), Order ("O2" , "item2" , 20.0 )]
146146 assert encode_engine_value (orders ) == [
147147 ["O1" , "item1" , 10.0 , "default_extra" ],
@@ -158,12 +158,12 @@ def test_encode_engine_value_list_of_structs():
158158 ]
159159
160160
161- def test_encode_engine_value_struct_with_list ():
161+ def test_encode_engine_value_struct_with_list () -> None :
162162 basket = Basket (items = ["apple" , "banana" ])
163163 assert encode_engine_value (basket ) == [["apple" , "banana" ]]
164164
165165
166- def test_encode_engine_value_nested_struct ():
166+ def test_encode_engine_value_nested_struct () -> None :
167167 customer = Customer (name = "Alice" , order = Order ("O1" , "item1" , 10.0 ))
168168 assert encode_engine_value (customer ) == [
169169 "Alice" ,
@@ -181,20 +181,20 @@ def test_encode_engine_value_nested_struct():
181181 ]
182182
183183
184- def test_encode_engine_value_empty_list ():
184+ def test_encode_engine_value_empty_list () -> None :
185185 assert encode_engine_value ([]) == []
186186 assert encode_engine_value ([[]]) == [[]]
187187
188188
189- def test_encode_engine_value_tuple ():
189+ def test_encode_engine_value_tuple () -> None :
190190 assert encode_engine_value (()) == []
191191 assert encode_engine_value ((1 , 2 , 3 )) == [1 , 2 , 3 ]
192192 assert encode_engine_value (((1 , 2 ), (3 , 4 ))) == [[1 , 2 ], [3 , 4 ]]
193193 assert encode_engine_value (([],)) == [[]]
194194 assert encode_engine_value (((),)) == [[]]
195195
196196
197- def test_encode_engine_value_none ():
197+ def test_encode_engine_value_none () -> None :
198198 assert encode_engine_value (None ) is None
199199
200200
@@ -323,18 +323,18 @@ def test_make_engine_value_decoder_basic_types() -> None:
323323 ),
324324 ],
325325)
326- def test_struct_decoder_cases (data_type , engine_val , expected ) :
326+ def test_struct_decoder_cases (data_type : Any , engine_val : Any , expected : Any ) -> None :
327327 decoder = build_engine_value_decoder (data_type )
328328 assert decoder (engine_val ) == expected
329329
330330
331- def test_make_engine_value_decoder_collections () :
331+ def test_make_engine_value_decoder_list_of_struct () -> None :
332332 # List of structs (dataclass)
333- decoder = build_engine_value_decoder (list [Order ])
334333 engine_val = [
335334 ["O1" , "item1" , 10.0 , "default_extra" ],
336335 ["O2" , "item2" , 20.0 , "default_extra" ],
337336 ]
337+ decoder = build_engine_value_decoder (list [Order ])
338338 assert decoder (engine_val ) == [
339339 Order ("O1" , "item1" , 10.0 , "default_extra" ),
340340 Order ("O2" , "item2" , 20.0 , "default_extra" ),
@@ -347,13 +347,15 @@ def test_make_engine_value_decoder_collections():
347347 OrderNamedTuple ("O2" , "item2" , 20.0 , "default_extra" ),
348348 ]
349349
350+
351+ def test_make_engine_value_decoder_struct_of_list () -> None :
350352 # Struct with list field
351- decoder = build_engine_value_decoder (Customer )
352353 engine_val = [
353354 "Alice" ,
354355 ["O1" , "item1" , 10.0 , "default_extra" ],
355356 [["vip" ], ["premium" ]],
356357 ]
358+ decoder = build_engine_value_decoder (Customer )
357359 assert decoder (engine_val ) == Customer (
358360 "Alice" ,
359361 Order ("O1" , "item1" , 10.0 , "default_extra" ),
@@ -368,8 +370,9 @@ def test_make_engine_value_decoder_collections():
368370 [Tag ("vip" ), Tag ("premium" )],
369371 )
370372
373+
374+ def test_make_engine_value_decoder_struct_of_struct () -> None :
371375 # Struct with struct field
372- decoder = build_engine_value_decoder (NestedStruct )
373376 engine_val = [
374377 ["Alice" , ["O1" , "item1" , 10.0 , "default_extra" ], [["vip" ]]],
375378 [
@@ -378,6 +381,7 @@ def test_make_engine_value_decoder_collections():
378381 ],
379382 2 ,
380383 ]
384+ decoder = build_engine_value_decoder (NestedStruct )
381385 assert decoder (engine_val ) == NestedStruct (
382386 Customer ("Alice" , Order ("O1" , "item1" , 10.0 , "default_extra" ), [Tag ("vip" )]),
383387 [
@@ -388,11 +392,13 @@ def test_make_engine_value_decoder_collections():
388392 )
389393
390394
391- def make_engine_order (fields ) :
395+ def make_engine_order (fields : list [ tuple [ str , type ]]) -> type :
392396 return make_dataclass ("EngineOrder" , fields )
393397
394398
395- def make_python_order (fields , defaults = None ):
399+ def make_python_order (
400+ fields : list [tuple [str , type ]], defaults : dict [str , Any ] | None = None
401+ ) -> type :
396402 if defaults is None :
397403 defaults = {}
398404 # Move all fields with defaults to the end (Python dataclass requirement)
@@ -466,8 +472,12 @@ def make_python_order(fields, defaults=None):
466472 ],
467473)
468474def test_field_position_cases (
469- engine_fields , python_fields , python_defaults , engine_val , expected_python_val
470- ):
475+ engine_fields : list [tuple [str , type ]],
476+ python_fields : list [tuple [str , type ]],
477+ python_defaults : dict [str , Any ],
478+ engine_val : list [Any ],
479+ expected_python_val : tuple [Any , ...],
480+ ) -> None :
471481 EngineOrder = make_engine_order (engine_fields )
472482 PythonOrder = make_python_order (python_fields , python_defaults )
473483 decoder = build_engine_value_decoder (EngineOrder , PythonOrder )
@@ -528,9 +538,9 @@ class OrderKey:
528538
529539
530540def test_vector_as_vector () -> None :
531- value : IntVectorType = [1 , 2 , 3 , 4 , 5 ]
541+ value = np . array ( [1 , 2 , 3 , 4 , 5 ], dtype = np . int64 )
532542 encoded = encode_engine_value (value )
533- assert encoded == [ 1 , 2 , 3 , 4 , 5 ]
543+ assert np . array_equal ( encoded , value )
534544 decoded = build_engine_value_decoder (IntVectorType )(encoded )
535545 assert np .array_equal (decoded , value )
536546
@@ -561,7 +571,7 @@ def test_vector_as_list() -> None:
561571NDArrayInt64Type = NDArray [np .int64 ]
562572
563573
564- def test_encode_engine_value_ndarray ():
574+ def test_encode_engine_value_ndarray () -> None :
565575 """Test encoding NDArray vectors to lists for the Rust engine."""
566576 vec_f32 : Float32VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
567577 assert np .array_equal (encode_engine_value (vec_f32 ), [1.0 , 2.0 , 3.0 ])
@@ -573,7 +583,7 @@ def test_encode_engine_value_ndarray():
573583 assert np .array_equal (encode_engine_value (vec_nd_f32 ), [1.0 , 2.0 , 3.0 ])
574584
575585
576- def test_make_engine_value_decoder_ndarray ():
586+ def test_make_engine_value_decoder_ndarray () -> None :
577587 """Test decoding engine lists to NDArray vectors."""
578588 decoder_f32 = build_engine_value_decoder (Float32VectorType )
579589 result_f32 = decoder_f32 ([1.0 , 2.0 , 3.0 ])
@@ -597,16 +607,16 @@ def test_make_engine_value_decoder_ndarray():
597607 assert np .array_equal (result_nd_f32 , np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 ))
598608
599609
600- def test_roundtrip_ndarray_vector ():
610+ def test_roundtrip_ndarray_vector () -> None :
601611 """Test roundtrip encoding and decoding of NDArray vectors."""
602- value_f32 : Float32VectorType = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
612+ value_f32 = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
603613 encoded_f32 = encode_engine_value (value_f32 )
604614 np .array_equal (encoded_f32 , [1.0 , 2.0 , 3.0 ])
605615 decoded_f32 = build_engine_value_decoder (Float32VectorType )(encoded_f32 )
606616 assert isinstance (decoded_f32 , np .ndarray )
607617 assert decoded_f32 .dtype == np .float32
608618 assert np .array_equal (decoded_f32 , value_f32 )
609- value_i64 : Int64VectorType = np .array ([1 , 2 , 3 ], dtype = np .int64 )
619+ value_i64 = np .array ([1 , 2 , 3 ], dtype = np .int64 )
610620 encoded_i64 = encode_engine_value (value_i64 )
611621 assert np .array_equal (encoded_i64 , [1 , 2 , 3 ])
612622 decoded_i64 = build_engine_value_decoder (Int64VectorType )(encoded_i64 )
@@ -622,18 +632,18 @@ def test_roundtrip_ndarray_vector():
622632 assert np .array_equal (decoded_nd_f64 , value_nd_f64 )
623633
624634
625- def test_ndarray_dimension_mismatch ():
635+ def test_ndarray_dimension_mismatch () -> None :
626636 """Test dimension enforcement for Vector with specified dimension."""
627- value : Float32VectorType = np .array ([1.0 , 2.0 ], dtype = np .float32 )
637+ value = np .array ([1.0 , 2.0 ], dtype = np .float32 )
628638 encoded = encode_engine_value (value )
629639 assert np .array_equal (encoded , [1.0 , 2.0 ])
630640 with pytest .raises (ValueError , match = "Vector dimension mismatch" ):
631641 build_engine_value_decoder (Float32VectorType )(encoded )
632642
633643
634- def test_list_vector_backward_compatibility ():
644+ def test_list_vector_backward_compatibility () -> None :
635645 """Test that list-based vectors still work for backward compatibility."""
636- value : IntVectorType = [1 , 2 , 3 , 4 , 5 ]
646+ value = [1 , 2 , 3 , 4 , 5 ]
637647 encoded = encode_engine_value (value )
638648 assert encoded == [1 , 2 , 3 , 4 , 5 ]
639649 decoded = build_engine_value_decoder (IntVectorType )(encoded )
@@ -647,7 +657,7 @@ def test_list_vector_backward_compatibility():
647657 assert np .array_equal (decoded , [1 , 2 , 3 , 4 , 5 ])
648658
649659
650- def test_encode_complex_structure_with_ndarray ():
660+ def test_encode_complex_structure_with_ndarray () -> None :
651661 """Test encoding a complex structure that includes an NDArray."""
652662
653663 @dataclass
@@ -660,17 +670,13 @@ class MyStructWithNDArray:
660670 name = "test_np" , data = np .array ([1.0 , 0.5 ], dtype = np .float32 ), value = 100
661671 )
662672 encoded = encode_engine_value (original )
663- expected = [
664- "test_np" ,
665- [1.0 , 0.5 ],
666- 100 ,
667- ]
668- assert encoded [0 ] == expected [0 ]
669- assert np .array_equal (encoded [1 ], expected [1 ])
670- assert encoded [2 ] == expected [2 ]
673+
674+ assert encoded [0 ] == original .name
675+ assert np .array_equal (encoded [1 ], original .data )
676+ assert encoded [2 ] == original .value
671677
672678
673- def test_decode_nullable_ndarray_none_or_value_input ():
679+ def test_decode_nullable_ndarray_none_or_value_input () -> None :
674680 """Test decoding a nullable NDArray with None or value inputs."""
675681 src_type_dict = {
676682 "kind" : "Vector" ,
@@ -694,7 +700,7 @@ def test_decode_nullable_ndarray_none_or_value_input():
694700 )
695701
696702
697- def test_decode_vector_string ():
703+ def test_decode_vector_string () -> None :
698704 """Test decoding a vector of strings works for Python native list type."""
699705 src_type_dict = {
700706 "kind" : "Vector" ,
@@ -705,7 +711,7 @@ def test_decode_vector_string():
705711 assert decoder (["hello" , "world" ]) == ["hello" , "world" ]
706712
707713
708- def test_decode_error_non_nullable_or_non_list_vector ():
714+ def test_decode_error_non_nullable_or_non_list_vector () -> None :
709715 """Test decoding errors for non-nullable vectors or non-list inputs."""
710716 src_type_dict = {
711717 "kind" : "Vector" ,
@@ -719,7 +725,7 @@ def test_decode_error_non_nullable_or_non_list_vector():
719725 decoder ("not a list" )
720726
721727
722- def test_dump_vector_type_annotation_with_dim ():
728+ def test_dump_vector_type_annotation_with_dim () -> None :
723729 """Test dumping a vector type annotation with a specified dimension."""
724730 expected_dump = {
725731 "type" : {
@@ -731,7 +737,7 @@ def test_dump_vector_type_annotation_with_dim():
731737 assert dump_engine_object (Float32VectorType ) == expected_dump
732738
733739
734- def test_dump_vector_type_annotation_no_dim ():
740+ def test_dump_vector_type_annotation_no_dim () -> None :
735741 """Test dumping a vector type annotation with no dimension."""
736742 expected_dump_no_dim = {
737743 "type" : {
0 commit comments