Skip to content

Commit f42da07

Browse files
authored
Merge branch 'main' into justinchu/consolidate-overloads
2 parents 7b34975 + 075fc4d commit f42da07

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8467,6 +8467,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
84678467
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
84688468
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
84698469

8470+
if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
8471+
# We can create a definitive split op if the input shape is static
8472+
# Only torch>=2.7 supports correctly generating the correct number of outputs for Split
8473+
outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8474+
return [op.Squeeze(out, [dim]) for out in outputs]
8475+
84708476
return op.SplitToSequence(self, axis=dim, keepdims=False)
84718477

84728478

tests/function_libs/torch_lib/ops_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from torch.utils import _pytree as pytree
4040

4141
import onnxscript
42+
from onnxscript._internal import version_utils
4243
from tests.function_libs.torch_lib import (
4344
error_reproduction,
4445
ops_test_common,
@@ -202,7 +203,7 @@ def run_test_output_match(
202203
reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
203204
if (
204205
op.name.startswith("split")
205-
or op.name.startswith("unbind")
206+
or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7"))
206207
or op.name
207208
in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
208209
):

0 commit comments

Comments
 (0)