We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 81f8444 commit 4eb25a3Copy full SHA for 4eb25a3
onnxscript/function_libs/torch_lib/ops/core.py
@@ -8764,6 +8764,11 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
8764
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
8765
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
8766
8767
+ if isinstance(self.shape[dim], int):
8768
+ # We can create a definitive split op if the input shape is static
8769
+ outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8770
+ return [op.Squeeze(out, [self.shape[dim]]) for out in outputs]
8771
+
8772
return op.SplitToSequence(self, axis=dim, keepdims=False)
8773
8774
0 commit comments