Skip to content

Commit 4eb25a3

Browse files
authored
Simplify aten_unbind when shape is static
Add static shape handling to aten_unbind function
1 parent 81f8444 commit 4eb25a3

File tree

1 file changed

+5
-0
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+5
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8764,6 +8764,11 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
87648764
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
87658765
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
87668766

8767+
if isinstance(self.shape[dim], int):
8768+
# We can create a definitive split op if the input shape is static
8769+
outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8770+
return [op.Squeeze(out, [self.shape[dim]]) for out in outputs]
8771+
87678772
return op.SplitToSequence(self, axis=dim, keepdims=False)
87688773

87698774

0 commit comments

Comments
 (0)