11import uuid
22import datetime
33from dataclasses import dataclass , make_dataclass
4- from typing import NamedTuple , Literal
4+ from typing import NamedTuple , Literal , Any , Callable
55import pytest
66import cocoindex
77from cocoindex .typing import encode_enriched_type
@@ -53,7 +53,9 @@ class CustomerNamedTuple(NamedTuple):
5353 tags : list [Tag ] | None = None
5454
5555
56- def build_engine_value_decoder (engine_type_in_py , python_type = None ):
56+ def build_engine_value_decoder (
57+ engine_type_in_py : Any , python_type : Any | None = None
58+ ) -> Callable [[Any ], Any ]:
5759 """
5860 Helper to build a converter for the given engine-side type (as represented in Python).
5961 If python_type is not specified, uses engine_type_in_py as the target.
@@ -62,6 +64,22 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
6264 return make_engine_value_decoder ([], engine_type , python_type or engine_type_in_py )
6365
6466
67+ def validate_full_roundtrip (
68+ value : Any , output_type : Any , input_type : Any | None = None
69+ ) -> None :
70+ from cocoindex import _engine
71+
72+ encoded_value = encode_engine_value (value )
73+ encoded_output_type = encode_enriched_type (output_type )["type" ]
74+ value_from_engine = _engine .testutil .seder_roundtrip (
75+ encoded_value , encoded_output_type
76+ )
77+ decoded_value = build_engine_value_decoder (input_type or output_type , output_type )(
78+ value_from_engine
79+ )
80+ assert decoded_value == value
81+
82+
6583def test_encode_engine_value_basic_types ():
6684 assert encode_engine_value (123 ) == 123
6785 assert encode_engine_value (3.14 ) == 3.14
@@ -434,57 +452,33 @@ def test_field_position_cases(
434452 assert decoder (engine_val ) == PythonOrder (** expected_dict )
435453
436454
437- def test_roundtrip_ltable ():
455+ def test_roundtrip_ltable () -> None :
438456 t = list [Order ]
439457 value = [Order ("O1" , "item1" , 10.0 ), Order ("O2" , "item2" , 20.0 )]
440- encoded = encode_engine_value (value )
441- assert encoded == [
442- ["O1" , "item1" , 10.0 , "default_extra" ],
443- ["O2" , "item2" , 20.0 , "default_extra" ],
444- ]
445- decoded = build_engine_value_decoder (t )(encoded )
446- assert decoded == value
458+ validate_full_roundtrip (value , t )
447459
448460 t_nt = list [OrderNamedTuple ]
449461 value_nt = [
450462 OrderNamedTuple ("O1" , "item1" , 10.0 ),
451463 OrderNamedTuple ("O2" , "item2" , 20.0 ),
452464 ]
453- encoded = encode_engine_value (value_nt )
454- assert encoded == [
455- ["O1" , "item1" , 10.0 , "default_extra" ],
456- ["O2" , "item2" , 20.0 , "default_extra" ],
457- ]
458- decoded = build_engine_value_decoder (t_nt )(encoded )
459- assert decoded == value_nt
465+ validate_full_roundtrip (value_nt , t_nt )
460466
461467
462- def test_roundtrip_ktable_str_key ():
468+ def test_roundtrip_ktable_str_key () -> None :
463469 t = dict [str , Order ]
464470 value = {"K1" : Order ("O1" , "item1" , 10.0 ), "K2" : Order ("O2" , "item2" , 20.0 )}
465- encoded = encode_engine_value (value )
466- assert encoded == [
467- ["K1" , "O1" , "item1" , 10.0 , "default_extra" ],
468- ["K2" , "O2" , "item2" , 20.0 , "default_extra" ],
469- ]
470- decoded = build_engine_value_decoder (t )(encoded )
471- assert decoded == value
471+ validate_full_roundtrip (value , t )
472472
473473 t_nt = dict [str , OrderNamedTuple ]
474474 value_nt = {
475475 "K1" : OrderNamedTuple ("O1" , "item1" , 10.0 ),
476476 "K2" : OrderNamedTuple ("O2" , "item2" , 20.0 ),
477477 }
478- encoded = encode_engine_value (value_nt )
479- assert encoded == [
480- ["K1" , "O1" , "item1" , 10.0 , "default_extra" ],
481- ["K2" , "O2" , "item2" , 20.0 , "default_extra" ],
482- ]
483- decoded = build_engine_value_decoder (t_nt )(encoded )
484- assert decoded == value_nt
478+ validate_full_roundtrip (value_nt , t_nt )
485479
486480
487- def test_roundtrip_ktable_struct_key ():
481+ def test_roundtrip_ktable_struct_key () -> None :
488482 @dataclass (frozen = True )
489483 class OrderKey :
490484 shop_id : str
@@ -495,26 +489,14 @@ class OrderKey:
495489 OrderKey ("A" , 3 ): Order ("O1" , "item1" , 10.0 ),
496490 OrderKey ("B" , 4 ): Order ("O2" , "item2" , 20.0 ),
497491 }
498- encoded = encode_engine_value (value )
499- assert encoded == [
500- [["A" , 3 ], "O1" , "item1" , 10.0 , "default_extra" ],
501- [["B" , 4 ], "O2" , "item2" , 20.0 , "default_extra" ],
502- ]
503- decoded = build_engine_value_decoder (t )(encoded )
504- assert decoded == value
492+ validate_full_roundtrip (value , t )
505493
506494 t_nt = dict [OrderKey , OrderNamedTuple ]
507495 value_nt = {
508496 OrderKey ("A" , 3 ): OrderNamedTuple ("O1" , "item1" , 10.0 ),
509497 OrderKey ("B" , 4 ): OrderNamedTuple ("O2" , "item2" , 20.0 ),
510498 }
511- encoded = encode_engine_value (value_nt )
512- assert encoded == [
513- [["A" , 3 ], "O1" , "item1" , 10.0 , "default_extra" ],
514- [["B" , 4 ], "O2" , "item2" , 20.0 , "default_extra" ],
515- ]
516- decoded = build_engine_value_decoder (t_nt )(encoded )
517- assert decoded == value_nt
499+ validate_full_roundtrip (value_nt , t_nt )
518500
519501
520502IntVectorType = cocoindex .Vector [int , Literal [5 ]]
0 commit comments