Skip to content

Commit 12599cf

Browse files
committed
Fix roll
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent c0abfbc commit 12599cf

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7223,11 +7223,106 @@ def aten_rnn_tanh_cell(
72237223
raise NotImplementedError()
72247224

72257225

7226-
# roll is decomposed by PyTorch
7226+
@torch_op("aten::roll", trace_only=True)
72277227
def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor:
72287228
"""roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
72297229

7230-
raise NotImplementedError()
7230+
if isinstance(shifts, int):
7231+
shifts = [shifts]
7232+
7233+
if isinstance(dims, int):
7234+
dims = [dims]
7235+
7236+
self_rank = len(self.shape)
7237+
if self_rank == 0:
7238+
return op.Identity(self)
7239+
elif self.shape[0] == 0: # empty tensor
7240+
return op.Identity(self)
7241+
7242+
# NOTE: In pytorch, default value of dims is an empty list.
7243+
if len(dims) == 0: # Empty sequence
7244+
assert len(shifts) == 1, "shifts should be a single integer if dims is empty"
7245+
return _aten_roll_shift_no_dim_onnx(self, shifts[0])
7246+
else:
7247+
assert len(shifts) == len(dims)
7248+
result = self
7249+
for i, shift in enumerate(shifts):
7250+
dim = dims[i]
7251+
result = _aten_roll_shift_and_dim_onnx(result, shift, dim)
7252+
return result
7253+
7254+
7255+
@torch_op("aten::roll", trace_only=True, complex=True)
7256+
def aten_roll_complex(
7257+
self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()
7258+
) -> TTensor:
7259+
"""roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
7260+
7261+
if isinstance(shifts, int):
7262+
shifts = [shifts]
7263+
7264+
if isinstance(dims, int):
7265+
dims = [dims]
7266+
7267+
self_rank = len(self.shape)
7268+
if self_rank == 1:
7269+
return op.Identity(self)
7270+
7271+
if self.shape[0] == 0: # empty tensor
7272+
return op.Identity(self)
7273+
7274+
self_real = op.Slice(self, [0], [1], axes=[-1])
7275+
self_imag = op.Slice(self, [1], [2], axes=[-1])
7276+
if not dims:
7277+
assert len(shifts) == 1, "shifts should be a single integer if dims is empty"
7278+
shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0])
7279+
shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0])
7280+
7281+
result = op.Concat(shift_real, shift_imag, axis=-1)
7282+
7283+
else:
7284+
assert len(shifts) == len(dims)
7285+
for i, dim in enumerate(dims):
7286+
self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim)
7287+
self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim)
7288+
7289+
result = op.Concat(self_real, self_imag, axis=-1)
7290+
return result
7291+
7292+
7293+
def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor:
    """Roll a flattened view of ``self`` by ``shift`` and restore its shape."""
    minus_one = op.Constant(value_ints=[-1])
    # Work on the 1-D view: [[A,B],[C,D]] -> [A,B,C,D].
    flat = op.Reshape(self, minus_one)
    # Compute the split point: everything at or past it wraps to the front.
    if shift < 0:
        # shift=-1 on [A,B,C,D]: split length 1, so [A] wraps to the end.
        split = op.Constant(value_ints=[-shift])
    else:
        # shift=1 on [A,B,C,D]: split length 4 - 1 = 3, so [A,B,C] shifts
        # right — equivalently, [D] wraps to the beginning.
        split = op.Size(flat) - op.Constant(value_ints=[shift])
    # Elements before the split point, e.g. [A,B,C].
    head = op.Slice(flat, op.Constant(value_ints=[0]), split)
    # Elements from the split point to the end, e.g. [D].
    tail = op.Slice(flat, split, op.Reshape(op.Size(flat), minus_one))
    # Rotate by putting the tail first, e.g. [D,A,B,C].
    rolled = op.Concat(tail, head, axis=0)
    return op.Reshape(rolled, op.Shape(self))
7314+
def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor:
    """Roll ``self`` along axis ``dim`` by ``shift`` positions."""
    minus_one = op.Constant(value_ints=[-1])
    axis = op.Constant(value_ints=[dim])
    # Split point along `dim`: elements at or past it wrap to the front.
    if shift < 0:
        split = op.Constant(value_ints=[-shift])
    else:
        split = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift])
    # [A,B,C,D] -> [D,A,B,C]: [A,B,C] is the leading slice, [D] wraps around.
    # op.Size (total element count) is an upper bound for the slice end;
    # ONNX Slice clamps it to the axis length.
    leading = op.Slice(self, op.Constant(value_ints=[0]), split, axes=axis)
    wrapped = op.Slice(self, split, op.Reshape(op.Size(self), minus_one), axes=axis)
    return op.Concat(wrapped, leading, axis=dim)
72317326

72327327

72337328
def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,17 @@ def _where_input_wrangler(
16991699
TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d),
17001700
TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec),
17011701
TorchLibOpInfo("ones_like", core_ops.aten_ones_like),
1702+
TorchLibOpInfo(
1703+
"roll",
1704+
core_ops.aten_roll,
1705+
input_wrangler=_roll_input_wrangler,
1706+
),
1707+
TorchLibOpInfo(
1708+
"roll",
1709+
core_ops.aten_roll_complex,
1710+
input_wrangler=_roll_input_wrangler,
1711+
complex=True,
1712+
),
17021713
TorchLibOpInfo(
17031714
"scatter_reduce",
17041715
core_ops.aten_scatter_reduce,

0 commit comments

Comments
 (0)