55from typing import Annotated , Any , Literal , NamedTuple , get_args , get_origin
66
77import numpy as np
8- import pytest
98from numpy .typing import NDArray
109
1110from cocoindex .typing import (
2019 Vector ,
2120 VectorInfo ,
2221 analyze_type_info ,
22+ decode_engine_value_type ,
2323 encode_enriched_type ,
24+ encode_enriched_type_info ,
25+ encode_engine_value_type ,
2426)
2527
2628
@@ -32,7 +34,7 @@ class SimpleDataclass:
3234
3335class SimpleNamedTuple (NamedTuple ):
3436 name : str
35- value : Any
37+ value : int
3638
3739
3840def test_ndarray_float32_no_dim () -> None :
@@ -427,3 +429,125 @@ def test_unknown_type() -> None:
427429 typ = set
428430 result = analyze_type_info (typ )
429431 assert isinstance (result .variant , AnalyzedUnknownType )
432+
433+
434+ # ========================= Encode/Decode Tests =========================
435+
436+
437+ def encode_type_from_annotation (t : Any ) -> dict [str , Any ]:
438+ """Helper function to encode a Python type annotation to its dictionary representation."""
439+ return encode_enriched_type_info (analyze_type_info (t ))
440+
441+
442+ def test_basic_types_encode_decode () -> None :
443+ """Test encode/decode roundtrip for basic Python types."""
444+ test_cases = [
445+ str ,
446+ int ,
447+ float ,
448+ bool ,
449+ bytes ,
450+ uuid .UUID ,
451+ datetime .date ,
452+ datetime .time ,
453+ datetime .datetime ,
454+ datetime .timedelta ,
455+ ]
456+
457+ for typ in test_cases :
458+ encoded = encode_type_from_annotation (typ )
459+ decoded = decode_engine_value_type (encoded ["type" ])
460+ reencoded = encode_engine_value_type (decoded )
461+ assert reencoded == encoded ["type" ]
462+
463+
464+ def test_vector_types_encode_decode () -> None :
465+ """Test encode/decode roundtrip for vector types."""
466+ test_cases = [
467+ NDArray [np .float32 ],
468+ NDArray [np .float64 ],
469+ NDArray [np .int64 ],
470+ Vector [np .float32 ],
471+ Vector [np .float32 , Literal [128 ]],
472+ Vector [str ],
473+ ]
474+
475+ for typ in test_cases :
476+ encoded = encode_type_from_annotation (typ )
477+ decoded = decode_engine_value_type (encoded ["type" ])
478+ reencoded = encode_engine_value_type (decoded )
479+ assert reencoded == encoded ["type" ]
480+
481+
482+ def test_struct_types_encode_decode () -> None :
483+ """Test encode/decode roundtrip for struct types."""
484+ test_cases = [
485+ SimpleDataclass ,
486+ SimpleNamedTuple ,
487+ ]
488+
489+ for typ in test_cases :
490+ encoded = encode_type_from_annotation (typ )
491+ decoded = decode_engine_value_type (encoded ["type" ])
492+ reencoded = encode_engine_value_type (decoded )
493+ assert reencoded == encoded ["type" ]
494+
495+
496+ def test_table_types_encode_decode () -> None :
497+ """Test encode/decode roundtrip for table types."""
498+ test_cases = [
499+ list [SimpleDataclass ], # LTable
500+ dict [str , SimpleDataclass ], # KTable
501+ ]
502+
503+ for typ in test_cases :
504+ encoded = encode_type_from_annotation (typ )
505+ decoded = decode_engine_value_type (encoded ["type" ])
506+ reencoded = encode_engine_value_type (decoded )
507+ assert reencoded == encoded ["type" ]
508+
509+
510+ def test_nullable_types_encode_decode () -> None :
511+ """Test encode/decode roundtrip for nullable types."""
512+ test_cases = [
513+ str | None ,
514+ int | None ,
515+ NDArray [np .float32 ] | None ,
516+ ]
517+
518+ for typ in test_cases :
519+ encoded = encode_type_from_annotation (typ )
520+ decoded = decode_engine_value_type (encoded ["type" ])
521+ reencoded = encode_engine_value_type (decoded )
522+ assert reencoded == encoded ["type" ]
523+
524+
525+ def test_annotated_types_encode_decode () -> None :
526+ """Test encode/decode roundtrip for annotated types."""
527+ test_cases = [
528+ Annotated [str , TypeAttr ("key" , "value" )],
529+ Annotated [NDArray [np .float32 ], VectorInfo (dim = 256 )],
530+ Annotated [list [int ], VectorInfo (dim = 10 )],
531+ ]
532+
533+ for typ in test_cases :
534+ encoded = encode_type_from_annotation (typ )
535+ decoded = decode_engine_value_type (encoded ["type" ])
536+ reencoded = encode_engine_value_type (decoded )
537+ assert reencoded == encoded ["type" ]
538+
539+
540+ def test_complex_nested_encode_decode () -> None :
541+ """Test complex nested structure encode/decode roundtrip."""
542+
543+ # Create a complex nested structure using Python type annotations
544+ @dataclasses .dataclass
545+ class ComplexStruct :
546+ embedding : NDArray [np .float32 ]
547+ metadata : str | None
548+ score : Annotated [float , TypeAttr ("indexed" , True )]
549+
550+ encoded = encode_type_from_annotation (ComplexStruct )
551+ decoded = decode_engine_value_type (encoded ["type" ])
552+ reencoded = encode_engine_value_type (decoded )
553+ assert reencoded == encoded ["type" ]
0 commit comments