
Commit 759664d

Merge pull request #943 from p517332051/netease_fix_stack_bug
fix bug stack when dim -1
2 parents ba9f730 + c869393

2 files changed: +57, -29 lines changed


core/conversion/converters/impl/stack.cpp

Lines changed: 8 additions & 1 deletion
@@ -19,9 +19,16 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].IValue()->toListRef();
       auto dim = args[1].unwrapToInt();
+      if (-1 == dim) {
+        auto first_in = in[0];
+        if (first_in.isTensor()) {
+          dim = first_in.toTensor().ndimension();
+        } else {
+          dim = first_in.toCustomClass<TensorContainer>()->tensor()->getDimensions().nbDims;
+        }
+      }
 
       std::vector<nvinfer1::ITensor*> tensors;
-
       for (auto t : in) {
         nvinfer1::ITensor* itensor;
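
The converter change above resolves a dim of -1 to the rank of the first list element (ndimension() for an ATen tensor, getDimensions().nbDims for an ITensor held in a TensorContainer), since aten::stack inserts the new axis after the last existing dimension. A minimal standalone sketch of that normalization, using a hypothetical normalize_stack_dim helper that is not part of Torch-TensorRT and generalizes to any negative dim:

  #include <cstdint>
  #include <stdexcept>

  // Hypothetical helper (not in the Torch-TensorRT codebase): map a possibly
  // negative aten::stack dim onto the output rank. The output of stack has
  // input_rank + 1 dimensions, so for a rank-3 input, dim == -1 resolves to 3,
  // the same value the converter now derives from the first input's rank.
  int64_t normalize_stack_dim(int64_t dim, int64_t input_rank) {
    const int64_t output_rank = input_rank + 1;  // stack adds one axis
    if (dim < -output_rank || dim >= output_rank) {
      throw std::out_of_range("aten::stack dim out of range");
    }
    return dim < 0 ? dim + output_rank : dim;
  }

With input_rank = 3, normalize_stack_dim(-1, 3) returns 3, matching the test expectations below; the merged fix handles only the dim == -1 case explicitly.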

tests/core/conversion/converters/test_stack.cpp

Lines changed: 49 additions & 28 deletions
@@ -5,51 +5,72 @@
 #include "torch/csrc/jit/ir/irparser.h"
 
 TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
+  auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, g.get());
+
+    auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
+    auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
+
+    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+    ASSERT_TRUE(
+        torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  };
   const auto graph = R"IR(
       graph(%0 : Tensor,
             %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=3]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
+  const auto graph2 = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor):
+       %2 : Tensor[] = prim::ListConstruct(%0, %1)
+       %3 : int = prim::Constant[value=-1]()
+       %4 : Tensor = aten::stack(%2, %3)
+       return (%4))IR";
 
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
+  TestATenStackPureTensorConvertsCorrectly(graph);
+  TestATenStackPureTensorConvertsCorrectly(graph2);
+}
 
-  auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
-  auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
+TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
+  auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, g.get());
 
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+    auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
+    auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
 
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
+    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});
 
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
-}
+    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
+    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
 
-TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
+    ASSERT_TRUE(
+        torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+  };
   const auto graph = R"IR(
       graph(%0 : Tensor,
            %1 : Float(4, 4, 4, strides=[16, 4, 1])):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor = aten::stack(%2, %3)
        return (%4))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
-  auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
-
-  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
-  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});
-
-  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
-
-  ASSERT_TRUE(
-      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
-}
+  const auto graph2 = R"IR(
+      graph(%0 : Tensor,
+           %1 : Float(4, 4, 4, strides=[16, 4, 1])):
+       %2 : Tensor[] = prim::ListConstruct(%0, %1)
+       %3 : int = prim::Constant[value=-1]()
+       %4 : Tensor = aten::stack(%2, %3)
+       return (%4))IR";
+  TestATenStackDiffTensorConvertsCorrectly(graph);
+  TestATenStackDiffTensorConvertsCorrectly(graph2);
+}
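
The graph2 variants repeat each test with dim = -1. For two 4x4x4 inputs, eager PyTorch stacks them into a 4x4x4x2 result, which is the reference the TensorRT engine output is compared against. A quick libtorch sketch of that expected behavior (illustration only, not part of the test suite):

  #include <torch/torch.h>
  #include <iostream>

  int main() {
    // Two rank-3 tensors shaped like the test inputs.
    auto a = torch::randint(1, 10, {4, 4, 4});
    auto b = torch::randint(1, 10, {4, 4, 4});

    // dim = -1 places the new stacking axis after the last input dimension.
    auto stacked_neg = torch::stack({a, b}, /*dim=*/-1);
    std::cout << stacked_neg.sizes() << "\n";  // [4, 4, 4, 2]

    // Equivalent positive dim once -1 is normalized for a rank-3 input.
    auto stacked_pos = torch::stack({a, b}, /*dim=*/3);
    std::cout << stacked_neg.equal(stacked_pos) << "\n";  // 1 (identical)
    return 0;
  }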
