Skip to content

Commit d7c2794

Browse files
committed
support aten::roll
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent a6c27e0 commit d7c2794

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,34 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
6767
return true;
6868
}
6969

70+
nvinfer1::ITensor* roll(
71+
ConversionCtx* ctx,
72+
nvinfer1::ITensor* in,
73+
int shift,
74+
int dim,
75+
const std::vector<int64_t>& in_shape) {
76+
auto in_dim = in_shape[dim];
77+
78+
auto start = (in_dim - shift) % in_dim;
79+
// Behavior of % is different in C++ vs Python for negative numbers. This
80+
// corrects the difference.
81+
if (start < 0) {
82+
start = start + in_dim;
83+
}
84+
at::Tensor index0 = at::arange(start, in_dim, 1, torch::kInt32);
85+
at::Tensor index;
86+
if (start == 0) {
87+
index = index0;
88+
} else {
89+
at::Tensor index1 = at::arange(start, torch::kInt32);
90+
index = at::cat({index0, index1}, 0);
91+
}
92+
auto index_tensor = tensor_to_const(ctx, index);
93+
auto gather_layer = ctx->net->addGather(*in, *index_tensor, dim);
94+
auto out = gather_layer->getOutput(0);
95+
return out;
96+
}
97+
7098
auto select_registrations TORCHTRT_UNUSED =
7199
RegisterNodeConversionPatterns()
72100
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
@@ -202,6 +230,29 @@ auto select_registrations TORCHTRT_UNUSED =
202230

203231
return true;
204232
}})
233+
.pattern({"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)",
234+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
235+
auto in = args[0].ITensor();
236+
auto shifts = args[1].unwrapToIntList().vec();
237+
auto dims = args[2].unwrapToIntList().vec();
238+
239+
TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()");
240+
if (ctx->input_is_dynamic) {
241+
TORCHTRT_THROW_ERROR("aten::roll is currently not support in dynamic input shape compilation");
242+
} else {
243+
auto in_shape = util::toVec(in->getDimensions());
244+
for (size_t i = 0; i < dims.size(); i++) {
245+
auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i];
246+
TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
247+
in = roll(ctx, in, shifts[i], dim, in_shape);
248+
}
249+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
250+
251+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
252+
253+
return true;
254+
}
255+
}})
205256
.pattern(
206257
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
207258
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

tests/core/conversion/converters/test_select.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,84 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
195195
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
196196
}
197197

198+
TEST(Converters, ATenRollConvertsCorrectly) {
199+
const auto graph = R"IR(
200+
graph(%1 : Tensor):
201+
%2 : int[] = prim::Constant[value=[1, 0, 3, 7]]()
202+
%3 : int[] = prim::Constant[value=[0, 1, 2, 3]]()
203+
%4 : Tensor = aten::roll(%1, %2, %3)
204+
return (%4))IR";
205+
206+
auto g = std::make_shared<torch::jit::Graph>();
207+
208+
torch::jit::parseIR(graph, g.get());
209+
210+
// Run Pytorch
211+
auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA});
212+
213+
auto jit_in = at::clone(in);
214+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
215+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
216+
217+
auto trt_in = at::clone(in);
218+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
219+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
220+
221+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
222+
}
223+
224+
TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) {
225+
const auto graph = R"IR(
226+
graph(%1 : Tensor):
227+
%2 : int[] = prim::Constant[value=[0, -3, -3]]()
228+
%3 : int[] = prim::Constant[value=[1, 2, 3]]()
229+
%4 : Tensor = aten::roll(%1, %2, %3)
230+
return (%4))IR";
231+
232+
auto g = std::make_shared<torch::jit::Graph>();
233+
234+
torch::jit::parseIR(graph, g.get());
235+
236+
// Run Pytorch
237+
auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
238+
239+
auto jit_in = at::clone(in);
240+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
241+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
242+
243+
auto trt_in = at::clone(in);
244+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
245+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
246+
247+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
248+
}
249+
250+
TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) {
251+
const auto graph = R"IR(
252+
graph(%1 : Tensor):
253+
%2 : int[] = prim::Constant[value=[0, -3, -3]]()
254+
%3 : int[] = prim::Constant[value=[1, 2, -1]]()
255+
%4 : Tensor = aten::roll(%1, %2, %3)
256+
return (%4))IR";
257+
258+
auto g = std::make_shared<torch::jit::Graph>();
259+
260+
torch::jit::parseIR(graph, g.get());
261+
262+
// Run Pytorch
263+
auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
264+
265+
auto jit_in = at::clone(in);
266+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
267+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
268+
269+
auto trt_in = at::clone(in);
270+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
271+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
272+
273+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
274+
}
275+
198276
TEST(Converters, ATenSliceConvertsCorrectly) {
199277
const auto graph = R"IR(
200278
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)