Skip to content

Commit 958be30

Browse files
authored
Merge pull request #400 from guoruoqian/select_int_fix_bug
Fix bug where the dim argument is negative or bigger than one in the aten::select.int converter
2 parents 093327b + 3ab35c4 commit 958be30

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ auto select_registrations TRTORCH_UNUSED =
6868
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
6969
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
7070
auto in = args[0].ITensor();
71+
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
7172
auto axis = args[1].unwrapToInt();
73+
axis = axis < 0 ? axis + maxDim : axis;
7274
auto ind = (int32_t)args[2].unwrapToInt();
7375

7476
// index to access needs to be an at::Tensor
@@ -89,7 +91,7 @@ auto select_registrations TRTORCH_UNUSED =
8991
// IShuffleLayer removes redundant dimensions
9092
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
9193
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
92-
shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions()));
94+
shuffle_layer->setReshapeDimensions(util::squeezeDims(gather_out->getDimensions(), axis));
9395
shuffle_layer->setName(util::node_info(n).c_str());
9496
auto shuffle_out = shuffle_layer->getOutput(0);
9597

tests/core/conversion/converters/test_select.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,59 @@ TEST(Converters, ATenSelectIntConvertsCorrectly) {
3131
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
3232
}
3333

34+
TEST(Converters, ATenSelectIntDimIsOneConvertsCorrectly) {
35+
const auto graph = R"IR(
36+
graph(%0 : Tensor):
37+
%2 : int = prim::Constant[value=1]()
38+
%3 : int = prim::Constant[value=0]()
39+
%4 : Tensor = aten::select(%0, %2, %3)
40+
return (%4))IR";
41+
42+
auto g = std::make_shared<torch::jit::Graph>();
43+
44+
torch::jit::parseIR(graph, &*g);
45+
46+
auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
47+
48+
auto jit_in = at::clone(in);
49+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
50+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
51+
52+
auto trt_in = at::clone(in);
53+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
54+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
55+
56+
// Check that the output shapes match, i.e. that the selected dimension is actually squeezed out.
57+
// E.g. x = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}), then aten::select(x, 1, 0). We should get a tensor y with
58+
// shape {4, 4} instead of a tensor with shape {4, 1, 4}.
59+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
60+
}
61+
62+
TEST(Converters, ATenSelectIntDimNegativeConvertsCorrectly) {
63+
const auto graph = R"IR(
64+
graph(%0 : Tensor):
65+
%2 : int = prim::Constant[value=-2]()
66+
%3 : int = prim::Constant[value=0]()
67+
%4 : Tensor = aten::select(%0, %2, %3)
68+
return (%4))IR";
69+
70+
auto g = std::make_shared<torch::jit::Graph>();
71+
72+
torch::jit::parseIR(graph, &*g);
73+
74+
auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
75+
76+
auto jit_in = at::clone(in);
77+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
78+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
79+
80+
auto trt_in = at::clone(in);
81+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
82+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
83+
84+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
85+
}
86+
3487
TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
3588
const auto graph = R"IR(
3689
graph(%0 : Tensor):

0 commit comments

Comments
 (0)