88from enum import Enum
99from threading import Lock
1010
11- from .typing import encode_enriched_type , analyze_type_info , COLLECTION_TYPES
12- from .convert import to_engine_value
11+ from .typing import encode_enriched_type
12+ from .convert import to_engine_value , make_engine_value_converter
1313from . import _engine
1414
15-
1615class OpCategory (Enum ):
1716 """The category of the operation."""
1817 FUNCTION = "function"
@@ -60,75 +59,6 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
6059 result_type = executor .analyze (* args , ** kwargs )
6160 return (encode_enriched_type (result_type ), executor )
6261
63- def _make_engine_struct_value_converter (
64- field_path : list [str ],
65- src_fields : list [dict [str , Any ]],
66- dst_dataclass_type : type ,
67- ) -> Callable [[list ], Any ]:
68- """Make a converter from an engine field values to a Python value."""
69-
70- src_name_to_idx = {f ['name' ]: i for i , f in enumerate (src_fields )}
71- def make_closure_for_value (name : str , param : inspect .Parameter ) -> Callable [[list ], Any ]:
72- src_idx = src_name_to_idx .get (name )
73- if src_idx is not None :
74- field_path .append (f'.{ name } ' )
75- field_converter = _make_engine_value_converter (
76- field_path , src_fields [src_idx ]['type' ], param .annotation )
77- field_path .pop ()
78- return lambda values : field_converter (values [src_idx ])
79-
80- default_value = param .default
81- if default_value is inspect .Parameter .empty :
82- raise ValueError (
83- f"Field without default value is missing in input: { '' .join (field_path )} " )
84-
85- return lambda _ : default_value
86-
87- field_value_converters = [
88- make_closure_for_value (name , param )
89- for (name , param ) in inspect .signature (dst_dataclass_type ).parameters .items ()]
90-
91- return lambda values : dst_dataclass_type (
92- * (converter (values ) for converter in field_value_converters ))
93-
94- def _make_engine_value_converter (
95- field_path : list [str ],
96- src_type : dict [str , Any ],
97- dst_annotation ,
98- ) -> Callable [[Any ], Any ]:
99- """Make a converter from an engine value to a Python value."""
100-
101- src_type_kind = src_type ['kind' ]
102-
103- if dst_annotation is inspect .Parameter .empty :
104- if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES :
105- raise ValueError (f"Missing type annotation for `{ '' .join (field_path )} `."
106- f"It's required for { src_type_kind } type." )
107- return lambda value : value
108-
109- dst_type_info = analyze_type_info (dst_annotation )
110-
111- if src_type_kind != dst_type_info .kind :
112- raise ValueError (
113- f"Type mismatch for `{ '' .join (field_path )} `: "
114- f"passed in { src_type_kind } , declared { dst_annotation } ({ dst_type_info .kind } )" )
115-
116- if dst_type_info .dataclass_type is not None :
117- return _make_engine_struct_value_converter (
118- field_path , src_type ['fields' ], dst_type_info .dataclass_type )
119-
120- if src_type_kind in COLLECTION_TYPES :
121- field_path .append ('[*]' )
122- elem_type_info = analyze_type_info (dst_type_info .elem_type )
123- if elem_type_info .dataclass_type is None :
124- raise ValueError (f"Type mismatch for `{ '' .join (field_path )} `: "
125- f"declared `{ dst_type_info .kind } `, a dataclass type expected" )
126- elem_converter = _make_engine_struct_value_converter (
127- field_path , src_type ['row' ]['fields' ], elem_type_info .dataclass_type )
128- field_path .pop ()
129- return lambda value : [elem_converter (v ) for v in value ] if value is not None else None
130-
131- return lambda value : value
13262
13363_gpu_dispatch_lock = Lock ()
13464
@@ -190,7 +120,7 @@ def analyze(self, *args, **kwargs):
190120 raise ValueError (
191121 f"Too many positional arguments passed in: { len (args )} > { next_param_idx } " )
192122 self ._args_converters .append (
193- _make_engine_value_converter (
123+ make_engine_value_converter (
194124 [arg_name ], arg .value_type ['type' ], arg_param .annotation ))
195125 if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
196126 next_param_idx += 1
@@ -207,7 +137,7 @@ def analyze(self, *args, **kwargs):
207137 if expected_arg is None :
208138 raise ValueError (f"Unexpected keyword argument passed in: { kwarg_name } " )
209139 arg_param = expected_arg [1 ]
210- self ._kwargs_converters [kwarg_name ] = _make_engine_value_converter (
140+ self ._kwargs_converters [kwarg_name ] = make_engine_value_converter (
211141 [kwarg_name ], kwarg .value_type ['type' ], arg_param .annotation )
212142
213143 missing_args = [name for (name , arg ) in expected_kwargs
0 commit comments