File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed
onnxscript/function_libs/torch_lib/ops
tests/function_libs/torch_lib Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -8467,6 +8467,12 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
84678467def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
84688468 """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""
84698469
8470+ if isinstance(self.shape[dim], int) and not version_utils.torch_older_than("2.7"):
8471+ # We can create a definitive split op if the input shape is static
8472+ # Only torch>=2.7 supports correctly generating the correct number of outputs for Split
8473+ outputs = op.Split(self, axis=dim, num_outputs=self.shape[dim])
8474+ return [op.Squeeze(out, [dim]) for out in outputs]
8475+
84708476 return op.SplitToSequence(self, axis=dim, keepdims=False)
84718477
84728478
Original file line number Diff line number Diff line change 3939from torch.utils import _pytree as pytree
4040
4141import onnxscript
42+ from onnxscript._internal import version_utils
4243from tests.function_libs.torch_lib import (
4344 error_reproduction,
4445 ops_test_common,
@@ -202,7 +203,7 @@ def run_test_output_match(
202203 reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
203204 if (
204205 op.name.startswith("split")
205- or op.name.startswith("unbind")
206+ or ( op.name.startswith("unbind") and version_utils.torch_older_than("2.7") )
206207 or op.name
207208 in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
208209 ):
You can’t perform that action at this time.
0 commit comments