Skip to content

Commit e9a3123

Browse files
committed
Create e2e test for aten.roll
Signed-off-by: Justin Chu <[email protected]>
1 parent aa2cf4a commit e9a3123

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,21 @@ def forward(self, q, k, v):
225225
)
226226
_testing.assert_onnx_program(onnx_program)
227227

228+
def test_roll(self):
229+
class Model(torch.nn.Module):
230+
def forward(self, x):
231+
x = torch.roll(x, 1)
232+
x = torch.roll(x, 1, 0)
233+
x = torch.roll(x, -1, 0)
234+
x = torch.roll(x, shifts=(2, 1), dims=(0, 1))
235+
return x
236+
237+
model = Model()
238+
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
239+
240+
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
241+
_testing.assert_onnx_program(onnx_program)
242+
228243

229244
if __name__ == "__main__":
230245
unittest.main()

0 commit comments

Comments
 (0)