|
54 | 54 | ) |
55 | 55 | from onnxscript.onnx_opset import opset18 as op |
56 | 56 | from onnxscript.onnx_types import TensorType |
| 57 | +from onnxscript._internal import version_utils |
57 | 58 |
|
58 | 59 | _INT64_MAX = 9223372036854775807 |
59 | 60 | _INT64_MIN = -9223372036854775808 |
@@ -1647,12 +1648,40 @@ def aten_choose_qparams_optimized( |
1647 | 1648 | raise NotImplementedError() |
1648 | 1649 |
|
1649 | 1650 |
|
1650 | | -@torch_op("aten::chunk", trace_only=True) |
1651 | | -def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: |
1652 | | - """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" |
1653 | | - if chunks == 1: |
1654 | | - return op.Identity(self) |
1655 | | - return op.Split(self, axis=dim, num_outputs=chunks) |
| 1651 | +if version_utils.torch_older_than("2.7.0"): |
| 1652 | + # PyTorch <2.7 does not support determining the number of outputs for the Split op |
| 1653 | + # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6 |
| 1654 | + @torch_op("aten::chunk") |
| 1655 | + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: |
| 1656 | + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" |
| 1657 | + # This will create a Sequence of tensors |
| 1658 | + neg_1 = op.Constant(value_ints=[-1]) |
| 1659 | + # Get size of specified dim |
| 1660 | + self_shape = op.Shape(self) |
| 1661 | + dim_size = op.Gather(self_shape, dim, axis=0) |
| 1662 | + # Compute size/chunk to get the number of data in one chunk |
| 1663 | + num_per_chunk = op.Div(dim_size, chunks) |
| 1664 | + num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] |
| 1665 | + |
| 1666 | + # Compute real chunk number |
| 1667 | + num_chunk = op.Div(dim_size, num_per_chunk) |
| 1668 | + # Get something like [n, n, n, n, ...], total num_chunk |
| 1669 | + list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) |
| 1670 | + |
| 1671 | + remainder = op.Mod(dim_size, num_per_chunk) |
| 1672 | + if remainder > 0: # type: ignore[operator] |
| 1673 | + # Append the remainder to the [n, n, n, n, ..., r] |
| 1674 | + list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) |
| 1675 | + |
| 1676 | + return op.SplitToSequence(self, list_split, axis=dim) |
| 1677 | +else: |
| 1678 | + |
| 1679 | + @torch_op("aten::chunk", trace_only=True) |
| 1680 | + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: |
| 1681 | + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" |
| 1682 | + if chunks == 1: |
| 1683 | + return op.Identity(self) |
| 1684 | + return op.Split(self, axis=dim, num_outputs=chunks) |
1656 | 1685 |
|
1657 | 1686 |
|
1658 | 1687 | @torch_op("aten::clamp", trace_only=True) |
|
0 commit comments