@@ -114,8 +114,8 @@ def __call__(
114114 ) -> tuple [dict [str , Any ], Executor ]:
115115 spec = _load_spec_from_engine (self ._spec_cls , spec )
116116 executor = self ._executor_cls (spec )
117- result_type = executor .analyze (* args , ** kwargs )
118- return (encode_enriched_type ( result_type ) , executor )
117+ result_type = executor .analyze_schema (* args , ** kwargs )
118+ return (result_type , executor )
119119
120120
121121_gpu_dispatch_lock = asyncio .Lock ()
@@ -156,6 +156,12 @@ def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
156156 return lambda * args , ** kwargs : asyncio .to_thread (lambda : call (* args , ** kwargs ))
157157
158158
159+ @dataclasses .dataclass
160+ class _ArgInfo :
161+ decoder : Callable [[Any ], Any ]
162+ is_required : bool
163+
164+
159165def _register_op_factory (
160166 category : OpCategory ,
161167 expected_args : list [tuple [str , inspect .Parameter ]],
@@ -176,37 +182,54 @@ def behavior_version(self) -> int | None:
176182 return op_args .behavior_version
177183
178184 class _WrappedClass (executor_cls , _Fallback ): # type: ignore[misc]
179- _args_decoders : list [Callable [[ Any ], Any ] ]
180- _kwargs_decoders : dict [str , Callable [[ Any ], Any ] ]
185+ _args_info : list [_ArgInfo ]
186+ _kwargs_info : dict [str , _ArgInfo ]
181187 _acall : Callable [..., Awaitable [Any ]]
182188
183189 def __init__ (self , spec : Any ) -> None :
184190 super ().__init__ ()
185191 self .spec = spec
186192 self ._acall = _to_async_call (super ().__call__ )
187193
188- def analyze (
194+ def analyze_schema (
189195 self , * args : _engine .OpArgSchema , ** kwargs : _engine .OpArgSchema
190196 ) -> Any :
191197 """
192198 Analyze the spec and arguments. In this phase, argument types should be validated.
193199 It should return the expected result type for the current op.
194200 """
195- self ._args_decoders = []
196- self ._kwargs_decoders = {}
201+ self ._args_info = []
202+ self ._kwargs_info = {}
197203 attributes = []
198-
199- def process_attribute (arg_name : str , arg : _engine .OpArgSchema ) -> None :
204+ potentially_missing_required_arg = False
205+
206+ def process_arg (
207+ arg_name : str ,
208+ arg_param : inspect .Parameter ,
209+ actual_arg : _engine .OpArgSchema ,
210+ ) -> _ArgInfo :
211+ nonlocal potentially_missing_required_arg
200212 if op_args .arg_relationship is not None :
201213 related_attr , related_arg_name = op_args .arg_relationship
202214 if related_arg_name == arg_name :
203215 attributes .append (
204- TypeAttr (related_attr .value , arg .analyzed_value )
216+ TypeAttr (related_attr .value , actual_arg .analyzed_value )
205217 )
218+ type_info = analyze_type_info (arg_param .annotation )
219+ decoder = make_engine_value_decoder (
220+ [arg_name ], actual_arg .value_type ["type" ], type_info
221+ )
222+ is_required = not type_info .nullable
223+ if is_required and actual_arg .value_type .get ("nullable" , False ):
224+ potentially_missing_required_arg = True
225+ return _ArgInfo (
226+ decoder = decoder ,
227+ is_required = is_required ,
228+ )
206229
207230 # Match arguments with parameters.
208231 next_param_idx = 0
209- for arg in args :
232+ for actual_arg in args :
210233 if next_param_idx >= len (expected_args ):
211234 raise ValueError (
212235 f"Too many arguments passed in: { len (args )} > { len (expected_args )} "
@@ -219,20 +242,13 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
219242 raise ValueError (
220243 f"Too many positional arguments passed in: { len (args )} > { next_param_idx } "
221244 )
222- self ._args_decoders .append (
223- make_engine_value_decoder (
224- [arg_name ],
225- arg .value_type ["type" ],
226- analyze_type_info (arg_param .annotation ),
227- )
228- )
229- process_attribute (arg_name , arg )
245+ self ._args_info .append (process_arg (arg_name , arg_param , actual_arg ))
230246 if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
231247 next_param_idx += 1
232248
233249 expected_kwargs = expected_args [next_param_idx :]
234250
235- for kwarg_name , kwarg in kwargs .items ():
251+ for kwarg_name , actual_arg in kwargs .items ():
236252 expected_arg = next (
237253 (
238254 arg
@@ -254,12 +270,9 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
254270 f"Unexpected keyword argument passed in: { kwarg_name } "
255271 )
256272 arg_param = expected_arg [1 ]
257- self ._kwargs_decoders [kwarg_name ] = make_engine_value_decoder (
258- [kwarg_name ],
259- kwarg .value_type ["type" ],
260- analyze_type_info (arg_param .annotation ),
273+ self ._kwargs_info [kwarg_name ] = process_arg (
274+ kwarg_name , arg_param , actual_arg
261275 )
262- process_attribute (kwarg_name , kwarg )
263276
264277 missing_args = [
265278 name
@@ -280,32 +293,45 @@ def process_attribute(arg_name: str, arg: _engine.OpArgSchema) -> None:
280293 if len (missing_args ) > 0 :
281294 raise ValueError (f"Missing arguments: { ', ' .join (missing_args )} " )
282295
283- prepare_method = getattr (executor_cls , "analyze" , None )
284- if prepare_method is not None :
285- result = prepare_method (self , * args , ** kwargs )
296+ base_analyze_method = getattr (self , "analyze" , None )
297+ if base_analyze_method is not None :
298+ result = base_analyze_method (self , * args , ** kwargs )
286299 else :
287300 result = expected_return
288301 if len (attributes ) > 0 :
289302 result = Annotated [result , * attributes ]
290- return result
303+
304+ encoded_type = encode_enriched_type (result )
305+ if potentially_missing_required_arg :
306+ encoded_type ["nullable" ] = True
307+ return encoded_type
291308
292309 async def prepare (self ) -> None :
293310 """
294311 Prepare for execution.
295312 It's executed after `analyze` and before any `__call__` execution.
296313 """
297- setup_method = getattr (super (), "prepare" , None )
298- if setup_method is not None :
299- await _to_async_call (setup_method )()
314+ prepare_method = getattr (super (), "prepare" , None )
315+ if prepare_method is not None :
316+ await _to_async_call (prepare_method )()
300317
301318 async def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
302- decoded_args = (
303- decoder (arg ) for decoder , arg in zip (self ._args_decoders , args )
304- )
305- decoded_kwargs = {
306- arg_name : self ._kwargs_decoders [arg_name ](arg )
307- for arg_name , arg in kwargs .items ()
308- }
319+ decoded_args = []
320+ for arg_info , arg in zip (self ._args_info , args ):
321+ if arg_info .is_required and arg is None :
322+ return None
323+ decoded_args .append (arg_info .decoder (arg ))
324+
325+ decoded_kwargs = {}
326+ for kwarg_name , arg in kwargs .items ():
327+ kwarg_info = self ._kwargs_info .get (kwarg_name )
328+ if kwarg_info is None :
329+ raise ValueError (
330+ f"Unexpected keyword argument passed in: { kwarg_name } "
331+ )
332+ if kwarg_info .is_required and arg is None :
333+ return None
334+ decoded_kwargs [kwarg_name ] = kwarg_info .decoder (arg )
309335
310336 if op_args .gpu :
311337 # For GPU executions, data-level parallelism is applied, so we don't want to
0 commit comments