@@ -369,10 +369,30 @@ async def prepare(self) -> None:
369369
370370 async def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
371371 decoded_args = []
372- for arg_info , arg in zip (self ._args_info , args ):
373- if arg_info .is_required and arg is None :
372+ skipped_idx : list [int ] | None = None
373+ if op_args .batching :
374+ if len (args ) != 1 :
375+ raise ValueError (
376+ "Batching is only supported for single argument functions"
377+ )
378+ arg_info = self ._args_info [0 ]
379+ if arg_info .is_required and args [0 ] is None :
374380 return None
375- decoded_args .append (arg_info .decoder (arg ))
381+ decoded = arg_info .decoder (args [0 ])
382+ if arg_info .is_required :
383+ skipped_idx = [i for i , arg in enumerate (decoded ) if arg is None ]
384+ if len (skipped_idx ) > 0 :
385+ decoded = [v for v in decoded if v is not None ]
386+ if len (decoded ) == 0 :
387+ return [None for _ in range (len (skipped_idx ))]
388+ else :
389+ skipped_idx = None
390+ decoded_args .append (decoded )
391+ else :
392+ for arg_info , arg in zip (self ._args_info , args ):
393+ if arg_info .is_required and arg is None :
394+ return None
395+ decoded_args .append (arg_info .decoder (arg ))
376396
377397 decoded_kwargs = {}
378398 for kwarg_name , arg in kwargs .items ():
@@ -387,7 +407,25 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
387407
388408 assert self ._acall is not None
389409 output = await self ._acall (* decoded_args , ** decoded_kwargs )
390- return self ._result_encoder (output )
410+
411+ if skipped_idx is None :
412+ return self ._result_encoder (output )
413+
414+ padded_output : list [Any ] = []
415+ next_idx = 0
416+ for v in output :
417+ while next_idx < len (skipped_idx ) and skipped_idx [next_idx ] == len (
418+ padded_output
419+ ):
420+ next_idx += 1
421+ padded_output .append (None )
422+ padded_output .append (v )
423+
424+ while next_idx < len (skipped_idx ):
425+ padded_output .append (None )
426+ next_idx += 1
427+
428+ return self ._result_encoder (padded_output )
391429
392430 def enable_cache (self ) -> bool :
393431 return op_args .cache
0 commit comments