Commit b5d79ea

fix(torchlib): aten::unbind.int uses SplitToSequence(keepdims=False) to match PyTorch shapes
1 parent b2d94fe
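
For context on why keepdims matters here: torch.unbind removes the dimension it splits along, while an ONNX SplitToSequence with keepdims=True would keep that axis with size 1. A minimal PyTorch sketch of the two shape conventions (illustrative only, not code from this commit):

import torch

x = torch.randn(3, 4)

# torch.unbind drops the split axis: 4 outputs, each of shape (3,)
unbound = torch.unbind(x, dim=1)
assert len(unbound) == 4 and all(t.shape == (3,) for t in unbound)

# torch.split keeps the axis (analogous to SplitToSequence with keepdims=True):
# 4 outputs, each of shape (3, 1)
kept = torch.split(x, 1, dim=1)
assert all(t.shape == (3, 1) for t in kept)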

2 files changed: 29 additions, 3 deletions

docs/test/test_documentation_examples.py

Lines changed: 23 additions & 0 deletions
@@ -56,3 +56,26 @@ def test(*relpath):
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
+
+# changed code
+import torch
+from onnxscript.function_libs.torch_lib import ops
+from onnxscript import evaluator
+
+def test_unbind_matches_torch():
+    x_torch = torch.randn(3, 4)
+    y_torch = torch.unbind(x_torch, dim=1)
+
+    # Convert input to NumPy for ONNXScript
+    x_np = x_torch.detach().cpu().numpy()
+
+    # Run in eager mode
+    eager = evaluator.default()
+    y_onnx = eager.eval_function(ops.core.aten_unbind, (x_np,), {"dim": 1})
+
+    # Compare number of outputs
+    assert len(y_torch) == len(y_onnx)
+
+    # Compare shapes
+    for a, b in zip(y_torch, y_onnx):
+        assert a.shape == tuple(b.shape), f"Shape mismatch: {a.shape} vs {b.shape}"

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 3 deletions
@@ -8618,10 +8618,13 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
 
 @torch_op("aten::unbind.int")
 def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
-    """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
+    # Use opset18 so keepdims=False works properly
+    shape = op.Shape(self)
+    dim_size = op.Gather(shape, op.Constant(value_int=dim), axis=0)
 
-    split_sizes = op.Constant(value_int=1)
-    return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)
+    # Split into slices of size 1 along `dim`, dropping the axis
+    split_size = op.Constant(value_int=1)
+    return op.SplitToSequence(self, split_size, axis=dim, keepdims=False)
 
 
 @torch_op("aten::unflatten.int", trace_only=True)
