@@ -64,106 +64,287 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
6464 )
6565
6666
67- def _get_type_info_safe (type_to_analyze : Any ) -> AnalyzedTypeInfo :
67+ def _get_type_info (type_to_analyze : Any ) -> AnalyzedTypeInfo :
6868 """Safely get type info, bypassing cache if type is not hashable."""
6969
7070 @functools .cache
71- def _get_cached_type_info () -> AnalyzedTypeInfo :
71+ def _get_cached_type_info (t : Any ) -> AnalyzedTypeInfo :
7272 """cache the computed type information for a given type."""
73- return analyze_type_info (type_to_analyze )
73+ return analyze_type_info (t )
7474
7575 try :
7676 return _get_cached_type_info (type_to_analyze )
77- except TypeError :
77+ except TypeError : # The type is not hashable
7878 return analyze_type_info (type_to_analyze )
7979
8080
81- def _encode_engine_value_core (
82- value : Any ,
83- type_info : AnalyzedTypeInfo | None = None ,
84- ) -> Any :
85- """Core encoding logic for converting Python values to engine values."""
86-
87- if dataclasses .is_dataclass (value ):
88- fields = dataclasses .fields (value )
89- return [
90- _encode_engine_value_core (
91- getattr (value , f .name ),
92- type_info = _get_type_info_safe (f .type ),
81+ def _make_encoder_closure (type_info : AnalyzedTypeInfo | None ) -> Callable [[Any ], Any ]:
82+ """
83+ Create an encoder closure for a specific type.
84+ """
85+ if type_info is None :
86+ # For untyped encoding, fall back to basic logic
87+ def encode_untyped (value : Any ) -> Any :
88+ if dataclasses .is_dataclass (value ):
89+ fields = dataclasses .fields (value )
90+ return [
91+ _make_encoder_closure (_get_type_info (f .type ))(
92+ getattr (value , f .name )
93+ )
94+ for f in fields
95+ ]
96+
97+ if is_namedtuple_type (type (value )):
98+ annotations = type (value ).__annotations__
99+ return [
100+ _make_encoder_closure (
101+ _get_type_info (annotations .get (name ))
102+ if annotations .get (name )
103+ else None
104+ )(getattr (value , name ))
105+ for name in value ._fields
106+ ]
107+
108+ if isinstance (value , np .number ):
109+ return value .item ()
110+
111+ if isinstance (value , np .ndarray ):
112+ return value
113+
114+ if isinstance (value , (list , tuple )):
115+ return [_make_encoder_closure (None )(v ) for v in value ]
116+
117+ if isinstance (value , dict ):
118+ # Handle empty dict
119+ if not value :
120+ return value
121+
122+ # Handle KTable
123+ first_val = next (iter (value .values ()))
124+ if is_struct_type (type (first_val )):
125+ return [
126+ [_make_encoder_closure (None )(k )]
127+ + _make_encoder_closure (None )(v )
128+ for k , v in value .items ()
129+ ]
130+
131+ return value
132+
133+ return encode_untyped
134+
135+ variant = type_info .variant
136+
137+ # Handle JSON types
138+ if isinstance (variant , AnalyzedBasicType ) and variant .kind == "Json" :
139+
140+ def encode_json_dict (value : Any ) -> Any :
141+ if isinstance (value , dict ):
142+ # Handle empty dict
143+ if not value :
144+ return value
145+
146+ # Handle KTable
147+ first_val = next (iter (value .values ()))
148+ if is_struct_type (type (first_val )):
149+ untyped_encoder = _make_encoder_closure (None )
150+ return [
151+ [untyped_encoder (k )] + untyped_encoder (v )
152+ for k , v in value .items ()
153+ ]
154+
155+ return value
156+
157+ return encode_json_dict
158+
159+ # Handle Any types and special numpy cases
160+ if isinstance (variant , AnalyzedAnyType ):
161+
162+ def encode_any_type (value : Any ) -> Any :
163+ # Handle numpy types first
164+ if isinstance (value , np .number ):
165+ return value .item ()
166+ if isinstance (value , np .ndarray ):
167+ return value
168+
169+ # Handle tuples - convert to lists for Any type
170+ if isinstance (value , tuple ):
171+ return [_make_encoder_closure (None )(v ) for v in value ]
172+
173+ # Handle dataclasses
174+ if dataclasses .is_dataclass (value ):
175+ fields = dataclasses .fields (value )
176+ return [
177+ _make_encoder_closure (_get_type_info (f .type ))(
178+ getattr (value , f .name )
179+ )
180+ for f in fields
181+ ]
182+
183+ # Handle namedtuples
184+ if is_namedtuple_type (type (value )):
185+ annotations = type (value ).__annotations__
186+ return [
187+ _make_encoder_closure (
188+ _get_type_info (annotations .get (name ))
189+ if annotations .get (name )
190+ else None
191+ )(getattr (value , name ))
192+ for name in value ._fields
193+ ]
194+
195+ # Handle lists
196+ if isinstance (value , list ):
197+ return [_make_encoder_closure (None )(v ) for v in value ]
198+
199+ # Handle dicts
200+ if isinstance (value , dict ):
201+ # Handle empty dict
202+ if not value :
203+ return value
204+
205+ # Handle KTable
206+ first_val = next (iter (value .values ()))
207+ if is_struct_type (type (first_val )):
208+ return [
209+ [_make_encoder_closure (None )(k )]
210+ + _make_encoder_closure (None )(v )
211+ for k , v in value .items ()
212+ ]
213+
214+ return value
215+
216+ return encode_any_type
217+
218+ # Handle basic types
219+ if isinstance (variant , AnalyzedBasicType ):
220+
221+ def encode_basic_with_numpy (value : Any ) -> Any :
222+ # Handle numpy types for basic types
223+ if isinstance (value , np .number ):
224+ return value .item ()
225+ if isinstance (value , np .ndarray ):
226+ return value
227+ return value
228+
229+ return encode_basic_with_numpy
230+
231+ # Handle lists
232+ if isinstance (variant , AnalyzedListType ):
233+ if variant .elem_type :
234+ elem_encoder = _make_encoder_closure (_get_type_info (variant .elem_type ))
235+ return (
236+ lambda value : [elem_encoder (v ) for v in value ]
237+ if isinstance (value , (list , tuple ))
238+ else value
93239 )
94- for f in fields
95- ]
96-
97- if is_namedtuple_type (type (value )):
98- annotations = type (value ).__annotations__
99- return [
100- _encode_engine_value_core (
101- getattr (value , name ),
102- type_info = _get_type_info_safe (annotations .get (name ))
103- if annotations .get (name )
104- else None ,
240+ else :
241+ fallback_encoder = _make_encoder_closure (None )
242+ return (
243+ lambda value : [fallback_encoder (v ) for v in value ]
244+ if isinstance (value , (list , tuple ))
245+ else value
105246 )
106- for name in value ._fields
107- ]
108247
109- if isinstance (value , np .number ):
110- return value .item ()
248+ # Handle dicts
249+ if isinstance (variant , AnalyzedDictType ):
250+ if variant .value_type :
251+ value_encoder = _make_encoder_closure (_get_type_info (variant .value_type ))
252+ untyped_encoder = _make_encoder_closure (None )
111253
112- if isinstance (value , np .ndarray ):
113- return value
254+ def encode_dict (value : Any ) -> Any :
255+ if not isinstance (value , dict ):
256+ return value
114257
115- if isinstance (value , (list , tuple )):
116- if (
117- type_info
118- and isinstance (type_info .variant , AnalyzedListType )
119- and type_info .variant .elem_type
120- ):
121- elem_type_info = _get_type_info_safe (type_info .variant .elem_type )
122- return [
123- _encode_engine_value_core (
124- v ,
125- type_info = elem_type_info ,
126- )
127- for v in value
128- ]
258+ # Handle empty dict
259+ if not value :
260+ return []
261+
262+ # Handle KTable
263+ first_val = next (iter (value .values ()))
264+ if is_struct_type (type (first_val )):
265+ return [
266+ [untyped_encoder (k )] + untyped_encoder (v )
267+ for k , v in value .items ()
268+ ]
269+
270+ # Handle regular dict
271+ return {k : value_encoder (v ) for k , v in value .items ()}
272+
273+ return encode_dict
129274 else :
130- return [_encode_engine_value_core (v , type_info = None ) for v in value ]
131-
132- if isinstance (value , dict ):
133- # Determine if this is a JSON type
134- is_json_type = False
135- if type_info and isinstance (type_info .variant , AnalyzedBasicType ):
136- is_json_type = type_info .variant .kind == "Json"
137-
138- # Handle empty dict
139- if not value :
140- return value if (not type_info or is_json_type ) else []
141-
142- # Handle KTable
143- first_val = next (iter (value .values ()))
144- if is_struct_type (type (first_val )):
145- return [
146- [_encode_engine_value_core (k , type_info = None )]
147- + _encode_engine_value_core (v , type_info = None )
148- for k , v in value .items ()
149- ]
275+ return lambda value : value
150276
151- # Handle regular dict
152- if (
153- type_info
154- and isinstance (type_info .variant , AnalyzedDictType )
155- and type_info .variant .value_type
156- ):
157- value_type_info = _get_type_info_safe (type_info .variant .value_type )
158- return {
159- k : _encode_engine_value_core (
160- v ,
161- type_info = value_type_info ,
277+ # Handle struct types
278+ if isinstance (variant , AnalyzedStructType ):
279+ struct_type = variant .struct_type
280+
281+ if dataclasses .is_dataclass (struct_type ):
282+ fields = dataclasses .fields (struct_type )
283+ field_encoders = [
284+ _make_encoder_closure (_get_type_info (f .type )) for f in fields
285+ ]
286+ field_names = [f .name for f in fields ]
287+
288+ def encode_dataclass (value : Any ) -> Any :
289+ if not dataclasses .is_dataclass (value ):
290+ return value
291+ return [
292+ encoder (getattr (value , name ))
293+ for encoder , name in zip (field_encoders , field_names )
294+ ]
295+
296+ return encode_dataclass
297+
298+ elif is_namedtuple_type (struct_type ):
299+ annotations = struct_type .__annotations__
300+ field_names = list (getattr (struct_type , "_fields" , ()))
301+ field_encoders = [
302+ _make_encoder_closure (
303+ _get_type_info (annotations .get (name ))
304+ if annotations .get (name )
305+ else None
162306 )
163- for k , v in value .items ()
164- }
307+ for name in field_names
308+ ]
309+
310+ def encode_namedtuple (value : Any ) -> Any :
311+ if not is_namedtuple_type (type (value )):
312+ return value
313+ return [
314+ encoder (getattr (value , name ))
315+ for encoder , name in zip (field_encoders , field_names )
316+ ]
317+
318+ return encode_namedtuple
319+
320+ # Handle numpy types
321+ def encode_with_numpy_check (value : Any ) -> Any :
322+ if isinstance (value , np .number ):
323+ return value .item ()
324+ if isinstance (value , np .ndarray ):
325+ return value
326+ if isinstance (value , tuple ):
327+ return [_make_encoder_closure (None )(v ) for v in value ]
328+ return value
329+
330+ return encode_with_numpy_check
331+
332+
333+ def make_engine_value_encoder (type_hint : Type [Any ] | str ) -> Callable [[Any ], Any ]:
334+ """
335+ Create an encoder closure for converting Python values to engine values.
336+
337+ Args:
338+ type_hint: Type annotation for the values to encode
165339
166- return value
340+ Returns:
341+ A closure that encodes Python values to engine values
342+ """
343+ type_info = _get_type_info (type_hint )
344+ if isinstance (type_info .variant , AnalyzedUnknownType ):
345+ raise ValueError (f"Type annotation `{ type_info .core_type } ` is unsupported" )
346+
347+ return _make_encoder_closure (type_info )
167348
168349
169350def encode_engine_value (value : Any , type_hint : Type [Any ] | str ) -> Any :
@@ -177,12 +358,8 @@ def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
177358 Returns:
178359 The encoded engine value
179360 """
180- # Analyze type once and reuse the result
181- type_info = _get_type_info_safe (type_hint )
182- if isinstance (type_info .variant , AnalyzedUnknownType ):
183- raise ValueError (f"Type annotation `{ type_info .core_type } ` is unsupported" )
184-
185- return _encode_engine_value_core (value , type_info )
361+ encoder = make_engine_value_encoder (type_hint )
362+ return encoder (value )
186363
187364
188365def make_engine_value_decoder (
0 commit comments