Skip to content

Commit 781e135

Browse files
committed
conditional
Signed-off-by: Justin Chu <[email protected]>
1 parent d905b2a commit 781e135

File tree

1 file changed

+35
-6
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+35
-6
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 35 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -54,6 +54,7 @@
5454
)
5555
from onnxscript.onnx_opset import opset18 as op
5656
from onnxscript.onnx_types import TensorType
57+
from onnxscript._internal import version_utils
5758

5859
_INT64_MAX = 9223372036854775807
5960
_INT64_MIN = -9223372036854775808
@@ -1647,12 +1648,40 @@ def aten_choose_qparams_optimized(
16471648
raise NotImplementedError()
16481649

16491650

1650-
@torch_op("aten::chunk", trace_only=True)
1651-
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
1652-
"""chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
1653-
if chunks == 1:
1654-
return op.Identity(self)
1655-
return op.Split(self, axis=dim, num_outputs=chunks)
1651+
if version_utils.torch_older_than("2.7.0"):
1652+
# PyTorch <2.7 does not support determining the number of outputs for the Split op
1653+
# https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6
1654+
@torch_op("aten::chunk")
1655+
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
1656+
"""chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
1657+
# This will create a Sequence of tensors
1658+
neg_1 = op.Constant(value_ints=[-1])
1659+
# Get size of specified dim
1660+
self_shape = op.Shape(self)
1661+
dim_size = op.Gather(self_shape, dim, axis=0)
1662+
# Compute size/chunk to get the number of data in one chunk
1663+
num_per_chunk = op.Div(dim_size, chunks)
1664+
num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator]
1665+
1666+
# Compute real chunk number
1667+
num_chunk = op.Div(dim_size, num_per_chunk)
1668+
# Get something like [n, n, n, n, ...], total num_chunk
1669+
list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
1670+
1671+
remainder = op.Mod(dim_size, num_per_chunk)
1672+
if remainder > 0: # type: ignore[operator]
1673+
# Append the remainder to the [n, n, n, n, ..., r]
1674+
list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
1675+
1676+
return op.SplitToSequence(self, list_split, axis=dim)
1677+
else:
1678+
1679+
@torch_op("aten::chunk", trace_only=True)
1680+
def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
1681+
"""chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
1682+
if chunks == 1:
1683+
return op.Identity(self)
1684+
return op.Split(self, axis=dim, num_outputs=chunks)
16561685

16571686

16581687
@torch_op("aten::clamp", trace_only=True)

0 commit comments

Comments (0)