Skip to content

Commit c37a943

Browse files
committed
some tensor like at::randint(1,10,{2,3,3}) dont need to squeeze, add a judgment condition to handle this situation
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 6bb9fbf commit c37a943

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

core/conversion/converters/impl/squeeze.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ auto squeeze_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pat
2525
dim = selfDim.size() + dim;
2626
}
2727

28+
if (selfDim[dim] != 1) {
29+
auto identity = ctx->net->addIdentity(*self);
30+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
31+
32+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
33+
34+
return true;
35+
}
36+
2837
auto shuffle_layer = ctx->net->addShuffle(*self);
2938
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
3039
shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));

tests/core/conversion/converters/test_squeeze.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,26 @@ TEST(Converters, ATenSqueezeConvertsCorrectly) {
2222
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
2323
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
2424

25+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
26+
}
27+
28+
TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
29+
const auto graph = R"IR(
30+
graph(%0 : Tensor):
31+
%1 : int = prim::Constant[value=1]()
32+
%2 : Tensor = aten::squeeze(%0, %1)
33+
return (%2))IR";
34+
35+
auto g = std::make_shared<torch::jit::Graph>();
36+
torch::jit::parseIR(graph, &*g);
37+
38+
auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
39+
40+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
41+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
42+
43+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
44+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
45+
2546
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
2647
}

0 commit comments

Comments
 (0)