Skip to content

Commit 6bbed3a

Browse files
authored
fix: skip None input arguments for batched custom functions (#1328)
1 parent 152d615 commit 6bbed3a

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

python/cocoindex/op.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,10 +369,30 @@ async def prepare(self) -> None:
369369

370370
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
371371
decoded_args = []
372-
for arg_info, arg in zip(self._args_info, args):
373-
if arg_info.is_required and arg is None:
372+
skipped_idx: list[int] | None = None
373+
if op_args.batching:
374+
if len(args) != 1:
375+
raise ValueError(
376+
"Batching is only supported for single argument functions"
377+
)
378+
arg_info = self._args_info[0]
379+
if arg_info.is_required and args[0] is None:
374380
return None
375-
decoded_args.append(arg_info.decoder(arg))
381+
decoded = arg_info.decoder(args[0])
382+
if arg_info.is_required:
383+
skipped_idx = [i for i, arg in enumerate(decoded) if arg is None]
384+
if len(skipped_idx) > 0:
385+
decoded = [v for v in decoded if v is not None]
386+
if len(decoded) == 0:
387+
return [None for _ in range(len(skipped_idx))]
388+
else:
389+
skipped_idx = None
390+
decoded_args.append(decoded)
391+
else:
392+
for arg_info, arg in zip(self._args_info, args):
393+
if arg_info.is_required and arg is None:
394+
return None
395+
decoded_args.append(arg_info.decoder(arg))
376396

377397
decoded_kwargs = {}
378398
for kwarg_name, arg in kwargs.items():
@@ -387,7 +407,25 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
387407

388408
assert self._acall is not None
389409
output = await self._acall(*decoded_args, **decoded_kwargs)
390-
return self._result_encoder(output)
410+
411+
if skipped_idx is None:
412+
return self._result_encoder(output)
413+
414+
padded_output: list[Any] = []
415+
next_idx = 0
416+
for v in output:
417+
while next_idx < len(skipped_idx) and skipped_idx[next_idx] == len(
418+
padded_output
419+
):
420+
next_idx += 1
421+
padded_output.append(None)
422+
padded_output.append(v)
423+
424+
while next_idx < len(skipped_idx):
425+
padded_output.append(None)
426+
next_idx += 1
427+
428+
return self._result_encoder(padded_output)
391429

392430
def enable_cache(self) -> bool:
393431
return op_args.cache

0 commit comments

Comments
 (0)