11import dataclasses
22import uuid
33import datetime
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , make_dataclass
55import pytest
6+ from cocoindex .typing import encode_enriched_type
67from cocoindex .convert import to_engine_value
8+ from cocoindex .convert import make_engine_value_converter
79
810@dataclass
911class Order :
1012 order_id : str
1113 name : str
1214 price : float
15+ extra_field : str = "default_extra"
16+
17+ @dataclass
18+ class Tag :
19+ name : str
1320
1421@dataclass
1522class Basket :
@@ -19,6 +26,21 @@ class Basket:
1926class Customer :
2027 name : str
2128 order : Order
29+ tags : list [Tag ] = None
30+
31+ @dataclass
32+ class NestedStruct :
33+ customer : Customer
34+ orders : list [Order ]
35+ count : int = 0
36+
37+ def build_engine_value_converter (engine_type_in_py , python_type = None ):
38+ """
39+ Helper to build a converter for the given engine-side type (as represented in Python).
40+ If python_type is not specified, uses engine_type_in_py as the target.
41+ """
42+ engine_type = encode_enriched_type (engine_type_in_py )["type" ]
43+ return make_engine_value_converter ([], engine_type , python_type or engine_type_in_py )
2244
2345def test_to_engine_value_basic_types ():
2446 assert to_engine_value (123 ) == 123
@@ -40,19 +62,19 @@ def test_to_engine_value_date_time_types():
4062
4163def test_to_engine_value_struct ():
4264 order = Order (order_id = "O123" , name = "mixed nuts" , price = 25.0 )
43- assert to_engine_value (order ) == ["O123" , "mixed nuts" , 25.0 ]
65+ assert to_engine_value (order ) == ["O123" , "mixed nuts" , 25.0 , "default_extra" ]
4466
4567def test_to_engine_value_list_of_structs ():
4668 orders = [Order ("O1" , "item1" , 10.0 ), Order ("O2" , "item2" , 20.0 )]
47- assert to_engine_value (orders ) == [["O1" , "item1" , 10.0 ], ["O2" , "item2" , 20.0 ]]
69+ assert to_engine_value (orders ) == [["O1" , "item1" , 10.0 , "default_extra" ], ["O2" , "item2" , 20.0 , "default_extra" ]]
4870
4971def test_to_engine_value_struct_with_list ():
5072 basket = Basket (items = ["apple" , "banana" ])
5173 assert to_engine_value (basket ) == [["apple" , "banana" ]]
5274
5375def test_to_engine_value_nested_struct ():
5476 customer = Customer (name = "Alice" , order = Order ("O1" , "item1" , 10.0 ))
55- assert to_engine_value (customer ) == ["Alice" , ["O1" , "item1" , 10.0 ] ]
77+ assert to_engine_value (customer ) == ["Alice" , ["O1" , "item1" , 10.0 , "default_extra" ], None ]
5678
5779def test_to_engine_value_empty_list ():
5880 assert to_engine_value ([]) == []
@@ -67,3 +89,146 @@ def test_to_engine_value_tuple():
6789
6890def test_to_engine_value_none ():
6991 assert to_engine_value (None ) is None
92+
93+ def test_make_engine_value_converter_basic_types ():
94+ for engine_type_in_py , value in [
95+ (int , 42 ),
96+ (float , 3.14 ),
97+ (str , "hello" ),
98+ (bool , True ),
99+ # (type(None), None), # Removed unsupported NoneType
100+ ]:
101+ converter = build_engine_value_converter (engine_type_in_py )
102+ assert converter (value ) == value
103+
104+ @pytest .mark .parametrize (
105+ "converter_type, engine_val, expected" ,
106+ [
107+ # All fields match
108+ (Order , ["O123" , "mixed nuts" , 25.0 , "default_extra" ], Order ("O123" , "mixed nuts" , 25.0 , "default_extra" )),
109+ # Extra field in engine value (should ignore extra)
110+ (Order , ["O123" , "mixed nuts" , 25.0 , "default_extra" , "unexpected" ], Order ("O123" , "mixed nuts" , 25.0 , "default_extra" )),
111+ # Fewer fields in engine value (should fill with default)
112+ (Order , ["O123" , "mixed nuts" , 0.0 , "default_extra" ], Order ("O123" , "mixed nuts" , 0.0 , "default_extra" )),
113+ # More fields in engine value (should ignore extra)
114+ (Order , ["O123" , "mixed nuts" , 25.0 , "unexpected" ], Order ("O123" , "mixed nuts" , 25.0 , "unexpected" )),
115+ # Truly extra field (should ignore the fifth field)
116+ (Order , ["O123" , "mixed nuts" , 25.0 , "default_extra" , "ignored" ], Order ("O123" , "mixed nuts" , 25.0 , "default_extra" )),
117+ # Missing optional field in engine value (tags=None)
118+ (Customer , ["Alice" , ["O1" , "item1" , 10.0 , "default_extra" ], None ], Customer ("Alice" , Order ("O1" , "item1" , 10.0 , "default_extra" ), None )),
119+ # Extra field in engine value for Customer (should ignore)
120+ (Customer , ["Alice" , ["O1" , "item1" , 10.0 , "default_extra" ], [["vip" ]], "extra" ], Customer ("Alice" , Order ("O1" , "item1" , 10.0 , "default_extra" ), [Tag ("vip" )])),
121+ ]
122+ )
123+ def test_struct_conversion_cases (converter_type , engine_val , expected ):
124+ converter = build_engine_value_converter (converter_type )
125+ assert converter (engine_val ) == expected
126+
127+ def test_make_engine_value_converter_collections ():
128+ # List of structs
129+ converter = build_engine_value_converter (list [Order ])
130+ engine_val = [
131+ ["O1" , "item1" , 10.0 , "default_extra" ],
132+ ["O2" , "item2" , 20.0 , "default_extra" ]
133+ ]
134+ assert converter (engine_val ) == [Order ("O1" , "item1" , 10.0 , "default_extra" ), Order ("O2" , "item2" , 20.0 , "default_extra" )]
135+ # Struct with list field
136+ converter = build_engine_value_converter (Customer )
137+ engine_val = ["Alice" , ["O1" , "item1" , 10.0 , "default_extra" ], [["vip" ], ["premium" ]]]
138+ assert converter (engine_val ) == Customer ("Alice" , Order ("O1" , "item1" , 10.0 , "default_extra" ), [Tag ("vip" ), Tag ("premium" )])
139+ # Struct with struct field
140+ converter = build_engine_value_converter (NestedStruct )
141+ engine_val = [
142+ ["Alice" , ["O1" , "item1" , 10.0 , "default_extra" ], [["vip" ]]],
143+ [["O1" , "item1" , 10.0 , "default_extra" ], ["O2" , "item2" , 20.0 , "default_extra" ]],
144+ 2
145+ ]
146+ assert converter (engine_val ) == NestedStruct (
147+ Customer ("Alice" , Order ("O1" , "item1" , 10.0 , "default_extra" ), [Tag ("vip" )]),
148+ [Order ("O1" , "item1" , 10.0 , "default_extra" ), Order ("O2" , "item2" , 20.0 , "default_extra" )],
149+ 2
150+ )
151+
152+ def make_engine_order (fields ):
153+ return make_dataclass ('EngineOrder' , fields )
154+
155+ def make_python_order (fields , defaults = None ):
156+ if defaults is None :
157+ defaults = {}
158+ # Move all fields with defaults to the end (Python dataclass requirement)
159+ non_default_fields = [(n , t ) for n , t in fields if n not in defaults ]
160+ default_fields = [(n , t ) for n , t in fields if n in defaults ]
161+ ordered_fields = non_default_fields + default_fields
162+ # Prepare the namespace for defaults (only for fields at the end)
163+ namespace = {k : defaults [k ] for k , _ in default_fields }
164+ return make_dataclass ('PythonOrder' , ordered_fields , namespace = namespace )
165+
166+ @pytest .mark .parametrize (
167+ "engine_fields, python_fields, python_defaults, engine_val, expected_python_val" ,
168+ [
169+ # Extra field in Python (middle)
170+ (
171+ [("id" , str ), ("name" , str )],
172+ [("id" , str ), ("price" , float ), ("name" , str )],
173+ {"price" : 0.0 },
174+ ["O123" , "mixed nuts" ],
175+ ("O123" , 0.0 , "mixed nuts" ),
176+ ),
177+ # Missing field in Python (middle)
178+ (
179+ [("id" , str ), ("price" , float ), ("name" , str )],
180+ [("id" , str ), ("name" , str )],
181+ {},
182+ ["O123" , 25.0 , "mixed nuts" ],
183+ ("O123" , "mixed nuts" ),
184+ ),
185+ # Extra field in Python (start)
186+ (
187+ [("name" , str ), ("price" , float )],
188+ [("extra" , str ), ("name" , str ), ("price" , float )],
189+ {"extra" : "default" },
190+ ["mixed nuts" , 25.0 ],
191+ ("default" , "mixed nuts" , 25.0 ),
192+ ),
193+ # Missing field in Python (start)
194+ (
195+ [("extra" , str ), ("name" , str ), ("price" , float )],
196+ [("name" , str ), ("price" , float )],
197+ {},
198+ ["unexpected" , "mixed nuts" , 25.0 ],
199+ ("mixed nuts" , 25.0 ),
200+ ),
201+ # Field order difference (should map by name)
202+ (
203+ [("id" , str ), ("name" , str ), ("price" , float )],
204+ [("name" , str ), ("id" , str ), ("price" , float ), ("extra" , str )],
205+ {"extra" : "default" },
206+ ["O123" , "mixed nuts" , 25.0 ],
207+ ("mixed nuts" , "O123" , 25.0 , "default" ),
208+ ),
209+ # Extra field (Python has extra field with default)
210+ (
211+ [("id" , str ), ("name" , str )],
212+ [("id" , str ), ("name" , str ), ("price" , float )],
213+ {"price" : 0.0 },
214+ ["O123" , "mixed nuts" ],
215+ ("O123" , "mixed nuts" , 0.0 ),
216+ ),
217+ # Missing field (Engine has extra field)
218+ (
219+ [("id" , str ), ("name" , str ), ("price" , float )],
220+ [("id" , str ), ("name" , str )],
221+ {},
222+ ["O123" , "mixed nuts" , 25.0 ],
223+ ("O123" , "mixed nuts" ),
224+ ),
225+ ]
226+ )
227+ def test_field_position_cases (engine_fields , python_fields , python_defaults , engine_val , expected_python_val ):
228+ EngineOrder = make_engine_order (engine_fields )
229+ PythonOrder = make_python_order (python_fields , python_defaults )
230+ converter = build_engine_value_converter (EngineOrder , PythonOrder )
231+ # Map field names to expected values
232+ expected_dict = dict (zip ([f [0 ] for f in python_fields ], expected_python_val ))
233+ # Instantiate using keyword arguments (order doesn't matter)
234+ assert converter (engine_val ) == PythonOrder (** expected_dict )
0 commit comments