11import uuid
22import datetime
33from dataclasses import dataclass , make_dataclass
4- from typing import NamedTuple , Literal
4+ from typing import NamedTuple , Literal , Any , Callable
55import pytest
66import cocoindex
77from cocoindex .typing import encode_enriched_type
@@ -53,7 +53,9 @@ class CustomerNamedTuple(NamedTuple):
5353 tags : list [Tag ] | None = None
5454
5555
56- def build_engine_value_decoder (engine_type_in_py , python_type = None ):
56+ def build_engine_value_decoder (
57+ engine_type_in_py : Any , python_type : Any | None = None
58+ ) -> Callable [[Any ], Any ]:
5759 """
5860 Helper to build a converter for the given engine-side type (as represented in Python).
5961 If python_type is not specified, uses engine_type_in_py as the target.
@@ -62,6 +64,27 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None):
6264 return make_engine_value_decoder ([], engine_type , python_type or engine_type_in_py )
6365
6466
67+ def validate_full_roundtrip (
68+ value : Any , output_type : Any , input_type : Any | None = None
69+ ) -> None :
70+ """
71+ Validate the given value doesn't change after encoding, sending to engine (using output_type), receiving back and decoding (using input_type).
72+
73+ If `input_type` is not specified, uses `output_type` as the target.
74+ """
75+ from cocoindex import _engine
76+
77+ encoded_value = encode_engine_value (value )
78+ encoded_output_type = encode_enriched_type (output_type )["type" ]
79+ value_from_engine = _engine .testutil .seder_roundtrip (
80+ encoded_value , encoded_output_type
81+ )
82+ decoded_value = build_engine_value_decoder (input_type or output_type , output_type )(
83+ value_from_engine
84+ )
85+ assert decoded_value == value
86+
87+
6588def test_encode_engine_value_basic_types ():
6689 assert encode_engine_value (123 ) == 123
6790 assert encode_engine_value (3.14 ) == 3.14
@@ -434,57 +457,33 @@ def test_field_position_cases(
434457 assert decoder (engine_val ) == PythonOrder (** expected_dict )
435458
436459
437- def test_roundtrip_ltable ():
460+ def test_roundtrip_ltable () -> None :
438461 t = list [Order ]
439462 value = [Order ("O1" , "item1" , 10.0 ), Order ("O2" , "item2" , 20.0 )]
440- encoded = encode_engine_value (value )
441- assert encoded == [
442- ["O1" , "item1" , 10.0 , "default_extra" ],
443- ["O2" , "item2" , 20.0 , "default_extra" ],
444- ]
445- decoded = build_engine_value_decoder (t )(encoded )
446- assert decoded == value
463+ validate_full_roundtrip (value , t )
447464
448465 t_nt = list [OrderNamedTuple ]
449466 value_nt = [
450467 OrderNamedTuple ("O1" , "item1" , 10.0 ),
451468 OrderNamedTuple ("O2" , "item2" , 20.0 ),
452469 ]
453- encoded = encode_engine_value (value_nt )
454- assert encoded == [
455- ["O1" , "item1" , 10.0 , "default_extra" ],
456- ["O2" , "item2" , 20.0 , "default_extra" ],
457- ]
458- decoded = build_engine_value_decoder (t_nt )(encoded )
459- assert decoded == value_nt
470+ validate_full_roundtrip (value_nt , t_nt )
460471
461472
462- def test_roundtrip_ktable_str_key ():
473+ def test_roundtrip_ktable_str_key () -> None :
463474 t = dict [str , Order ]
464475 value = {"K1" : Order ("O1" , "item1" , 10.0 ), "K2" : Order ("O2" , "item2" , 20.0 )}
465- encoded = encode_engine_value (value )
466- assert encoded == [
467- ["K1" , "O1" , "item1" , 10.0 , "default_extra" ],
468- ["K2" , "O2" , "item2" , 20.0 , "default_extra" ],
469- ]
470- decoded = build_engine_value_decoder (t )(encoded )
471- assert decoded == value
476+ validate_full_roundtrip (value , t )
472477
473478 t_nt = dict [str , OrderNamedTuple ]
474479 value_nt = {
475480 "K1" : OrderNamedTuple ("O1" , "item1" , 10.0 ),
476481 "K2" : OrderNamedTuple ("O2" , "item2" , 20.0 ),
477482 }
478- encoded = encode_engine_value (value_nt )
479- assert encoded == [
480- ["K1" , "O1" , "item1" , 10.0 , "default_extra" ],
481- ["K2" , "O2" , "item2" , 20.0 , "default_extra" ],
482- ]
483- decoded = build_engine_value_decoder (t_nt )(encoded )
484- assert decoded == value_nt
483+ validate_full_roundtrip (value_nt , t_nt )
485484
486485
487- def test_roundtrip_ktable_struct_key ():
486+ def test_roundtrip_ktable_struct_key () -> None :
488487 @dataclass (frozen = True )
489488 class OrderKey :
490489 shop_id : str
@@ -495,26 +494,14 @@ class OrderKey:
495494 OrderKey ("A" , 3 ): Order ("O1" , "item1" , 10.0 ),
496495 OrderKey ("B" , 4 ): Order ("O2" , "item2" , 20.0 ),
497496 }
498- encoded = encode_engine_value (value )
499- assert encoded == [
500- [["A" , 3 ], "O1" , "item1" , 10.0 , "default_extra" ],
501- [["B" , 4 ], "O2" , "item2" , 20.0 , "default_extra" ],
502- ]
503- decoded = build_engine_value_decoder (t )(encoded )
504- assert decoded == value
497+ validate_full_roundtrip (value , t )
505498
506499 t_nt = dict [OrderKey , OrderNamedTuple ]
507500 value_nt = {
508501 OrderKey ("A" , 3 ): OrderNamedTuple ("O1" , "item1" , 10.0 ),
509502 OrderKey ("B" , 4 ): OrderNamedTuple ("O2" , "item2" , 20.0 ),
510503 }
511- encoded = encode_engine_value (value_nt )
512- assert encoded == [
513- [["A" , 3 ], "O1" , "item1" , 10.0 , "default_extra" ],
514- [["B" , 4 ], "O2" , "item2" , 20.0 , "default_extra" ],
515- ]
516- decoded = build_engine_value_decoder (t_nt )(encoded )
517- assert decoded == value_nt
504+ validate_full_roundtrip (value_nt , t_nt )
518505
519506
520507IntVectorType = cocoindex .Vector [int , Literal [5 ]]
0 commit comments