Skip to content

Commit f0ff3cb

Browse files
committed
fix(outputs): Cast to sequence if num_outputs=1
1 parent d80575d commit f0ff3cb

File tree

1 file changed

+5
-1
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+5
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8957,7 +8957,11 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
89578957
if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
89588958
# We can create a definitive split op if the input shape is static
89598959
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
8960-
outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8960+
num_outputs = self.shape[dim]
8961+
outputs = op.Split(self, axis=dim, num_outputs=num_outputs)
8962+
if num_outputs == 1:
8963+
outputs = [outputs]
8964+
89618965
return [op.Squeeze(out, [dim]) for out in outputs]
89628966

89638967
return op.SplitToSequence(self, axis=dim, keepdims=False)

0 commit comments

Comments
 (0)