@@ -195,6 +195,84 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

+TEST(Converters, ATenRollConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%1 : Tensor):
+          %2 : int[] = prim::Constant[value=[1, 0, 3, 7]]()
+          %3 : int[] = prim::Constant[value=[0, 1, 2, 3]]()
+          %4 : Tensor = aten::roll(%1, %2, %3)
+          return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // Run Pytorch
+  auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%1 : Tensor):
+          %2 : int[] = prim::Constant[value=[0, -3, -3]]()
+          %3 : int[] = prim::Constant[value=[1, 2, 3]]()
+          %4 : Tensor = aten::roll(%1, %2, %3)
+          return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // Run Pytorch
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) {
+  const auto graph = R"IR(
+        graph(%1 : Tensor):
+          %2 : int[] = prim::Constant[value=[0, -3, -3]]()
+          %3 : int[] = prim::Constant[value=[1, 2, -1]]()
+          %4 : Tensor = aten::roll(%1, %2, %3)
+          return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // Run Pytorch
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
TEST(Converters, ATenSliceConvertsCorrectly) {
  const auto graph = R"IR(
        graph(%x.1 : Tensor):