Skip to content

Commit 093327b

Browse files
authored
Merge pull request #395 from guoruoqian/unsqueeze_fix_bug
fix a bug when unsqueeze(self, dim)'s dim is negative, the result of …
2 parents 7cd52cf + 22db4b7 commit 093327b

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

core/conversion/converters/impl/unsqueeze.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ auto unsqueeze_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().p
2121
auto dim = args[1].unwrapToInt();
2222

2323
auto selfDim = util::toVec(self->getDimensions());
24+
int64_t nbDims = selfDim.size();
25+
TRTORCH_CHECK(
26+
dim <= nbDims && dim >= -(nbDims + 1),
27+
"Dimension out of range (expected to be in range of [" << -(nbDims + 1) << ", " << nbDims << "], but got "
28+
<< dim << ")");
2429
if (dim < 0) {
25-
dim = selfDim.size() + dim;
30+
dim = nbDims + 1 + dim;
2631
}
2732

2833
auto shuffle_layer = ctx->net->addShuffle(*self);

tests/core/conversion/converters/test_unsqueeze.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,26 @@ TEST(Converters, ATenUnsqueezeConvertsCorrectly) {
2222
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
2323
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
2424

25+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
26+
}
27+
28+
TEST(Converters, ATenUnsqueezeNegativeDimConvertsCorrectly) {
29+
const auto graph = R"IR(
30+
graph(%0 : Tensor):
31+
%1 : int = prim::Constant[value=-4]()
32+
%2 : Tensor = aten::unsqueeze(%0, %1)
33+
return (%2))IR";
34+
35+
auto g = std::make_shared<torch::jit::Graph>();
36+
torch::jit::parseIR(graph, &*g);
37+
38+
auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
39+
40+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
41+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
42+
43+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
44+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
45+
2546
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
2647
}

0 commit comments

Comments
 (0)