Skip to content

Commit 2c9cf17

Browse files
committed
rm identity layer and reassoicate the input ITensor with the output node when we arent going to do anything
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent c37a943 commit 2c9cf17

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

core/conversion/converters/impl/squeeze.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ auto squeeze_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pat
2626
}
2727

2828
if (selfDim[dim] != 1) {
29-
auto identity = ctx->net->addIdentity(*self);
30-
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity->getOutput(0));
29+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);
3130

3231
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
3332

tests/core/conversion/converters/test_squeeze.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,30 @@ TEST(Converters, ATenSqueezeConvertsCorrectly) {
2727

2828
TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
2929
const auto graph = R"IR(
30-
graph(%0 : Tensor):
31-
%1 : int = prim::Constant[value=1]()
32-
%2 : Tensor = aten::squeeze(%0, %1)
33-
return (%2))IR";
30+
graph(%0 : Tensor, %1 : Tensor):
31+
%2 : int = prim::Constant[value=1]()
32+
%2.1 : Tensor = aten::add(%0, %1, %2)
33+
%3 : Tensor = aten::squeeze(%2.1, %2)
34+
%4 : Tensor = aten::add(%3, %1, %2)
35+
return (%4))IR";
3436

3537
auto g = std::make_shared<torch::jit::Graph>();
3638
torch::jit::parseIR(graph, &*g);
3739

3840
auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
41+
auto in_add = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
42+
43+
auto jit_in = at::clone(in);
44+
auto jit_in_add = at::clone(in_add);
3945

4046
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
41-
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
47+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in, jit_in_add});
48+
49+
auto trt_in = at::clone(jit_in);
50+
auto trt_in_add = at::clone(jit_in_add);
4251

4352
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
44-
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
53+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in, trt_in_add});
4554

4655
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
4756
}

0 commit comments

Comments
 (0)