Skip to content

Commit fe3e9b6

Browse files
authored
Merge pull request #393 from guoruoqian/squeeze_fix_bug
Fix bug in squeeze
2 parents 6bb9fbf + 2c9cf17 commit fe3e9b6

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

core/conversion/converters/impl/squeeze.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ auto squeeze_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pat
2525
dim = selfDim.size() + dim;
2626
}
2727

28+
if (selfDim[dim] != 1) {
29+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);
30+
31+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
32+
33+
return true;
34+
}
35+
2836
auto shuffle_layer = ctx->net->addShuffle(*self);
2937
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
3038
shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));

tests/core/conversion/converters/test_squeeze.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,35 @@ TEST(Converters, ATenSqueezeConvertsCorrectly) {
2222
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
2323
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
2424

25+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
26+
}
27+
28+
TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
29+
const auto graph = R"IR(
30+
graph(%0 : Tensor, %1 : Tensor):
31+
%2 : int = prim::Constant[value=1]()
32+
%2.1 : Tensor = aten::add(%0, %1, %2)
33+
%3 : Tensor = aten::squeeze(%2.1, %2)
34+
%4 : Tensor = aten::add(%3, %1, %2)
35+
return (%4))IR";
36+
37+
auto g = std::make_shared<torch::jit::Graph>();
38+
torch::jit::parseIR(graph, &*g);
39+
40+
auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
41+
auto in_add = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
42+
43+
auto jit_in = at::clone(in);
44+
auto jit_in_add = at::clone(in_add);
45+
46+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
47+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in, jit_in_add});
48+
49+
auto trt_in = at::clone(jit_in);
50+
auto trt_in_add = at::clone(jit_in_add);
51+
52+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
53+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in, trt_in_add});
54+
2555
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
2656
}

0 commit comments

Comments
 (0)