Skip to content

Commit b57875b

Browse files
committed
Fix aten_unbind for torch >= 2.7 dynamo export
Replace Split op with explicit Slice operations to fix TypeError when unbind is called during ONNX export with dynamo=True. The Split op with num_outputs parameter returns a non-iterable SymbolicTensor instead of a sequence, causing the list comprehension to fail. The fix uses individual Slice + Squeeze operations for each output, which properly handles symbolic tensors during graph construction. Fixes pytorch/pytorch#168969
1 parent 9dbf685 commit b57875b

File tree

1 file changed

+15
-8
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+15
-8
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9201,15 +9201,22 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
92019201
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
92029202

92039203
if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
9204-
# We can create a definitive split op if the input shape is static
9205-
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
9204+
# For torch>=2.7 with static shapes, use explicit Slice operations
9205+
# to avoid issues with Split returning a non-iterable SymbolicTensor
92069206
num_outputs = self.shape[dim]
9207-
if num_outputs != 1:
9208-
outputs = op.Split(self, axis=dim, num_outputs=num_outputs)
9209-
else:
9210-
outputs = [self]
9211-
9212-
return [op.Squeeze(out, [dim]) for out in outputs]
9207+
results = []
9208+
for i in range(num_outputs):
9209+
# Slice to get a single element at position i along dim
9210+
sliced = op.Slice(
9211+
self,
9212+
starts=op.Constant(value_ints=[i]),
9213+
ends=op.Constant(value_ints=[i + 1]),
9214+
axes=op.Constant(value_ints=[dim]),
9215+
)
9216+
# Squeeze to remove the dimension of size 1
9217+
squeezed = op.Squeeze(sliced, axes=[dim])
9218+
results.append(squeezed)
9219+
return results
92139220

92149221
return op.SplitToSequence(self, axis=dim, keepdims=False)
92159222

0 commit comments

Comments
 (0)