Skip to content

Commit c869393

Browse files
committed
test stack fix when dim is -1
Signed-off-by: hongwei03 <[email protected]>
1 parent 7b6733b commit c869393

File tree

1 file changed

+49
-28
lines changed

1 file changed

+49
-28
lines changed

tests/core/conversion/converters/test_stack.cpp

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,72 @@
55
#include "torch/csrc/jit/ir/irparser.h"
66

77
TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
8+
auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
9+
auto g = std::make_shared<torch::jit::Graph>();
10+
torch::jit::parseIR(graph, g.get());
11+
12+
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
13+
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
14+
15+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
16+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
17+
18+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
19+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
20+
21+
ASSERT_TRUE(
22+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
23+
};
824
const auto graph = R"IR(
925
graph(%0 : Tensor,
1026
%1 : Tensor):
1127
%2 : Tensor[] = prim::ListConstruct(%0, %1)
1228
%3 : int = prim::Constant[value=3]()
1329
%4 : Tensor = aten::stack(%2, %3)
1430
return (%4))IR";
31+
const auto graph2 = R"IR(
32+
graph(%0 : Tensor,
33+
%1 : Tensor):
34+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
35+
%3 : int = prim::Constant[value=-1]()
36+
%4 : Tensor = aten::stack(%2, %3)
37+
return (%4))IR";
1538

16-
auto g = std::make_shared<torch::jit::Graph>();
17-
torch::jit::parseIR(graph, g.get());
39+
TestATenStackPureTensorConvertsCorrectly(graph);
40+
TestATenStackPureTensorConvertsCorrectly(graph2);
41+
}
1842

19-
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
20-
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
43+
TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
44+
auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
45+
auto g = std::make_shared<torch::jit::Graph>();
46+
torch::jit::parseIR(graph, g.get());
2147

22-
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
23-
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
48+
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
49+
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
2450

25-
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
26-
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
51+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
52+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});
2753

28-
ASSERT_TRUE(
29-
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
30-
}
54+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
55+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
3156

32-
TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
57+
ASSERT_TRUE(
58+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
59+
};
3360
const auto graph = R"IR(
3461
graph(%0 : Tensor,
3562
%1 : Float(4, 4, 4, strides=[16, 4, 1])):
3663
%2 : Tensor[] = prim::ListConstruct(%0, %1)
3764
%3 : int = prim::Constant[value=1]()
3865
%4 : Tensor = aten::stack(%2, %3)
3966
return (%4))IR";
40-
41-
auto g = std::make_shared<torch::jit::Graph>();
42-
torch::jit::parseIR(graph, g.get());
43-
44-
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
45-
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
46-
47-
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
48-
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});
49-
50-
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
51-
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
52-
53-
ASSERT_TRUE(
54-
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
55-
}
67+
const auto graph2 = R"IR(
68+
graph(%0 : Tensor,
69+
%1 : Float(4, 4, 4, strides=[16, 4, 1])):
70+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
71+
%3 : int = prim::Constant[value=-1]()
72+
%4 : Tensor = aten::stack(%2, %3)
73+
return (%4))IR";
74+
TestATenStackDiffTensorConvertsCorrectly(graph);
75+
TestATenStackDiffTensorConvertsCorrectly(graph2);
76+
}

0 commit comments

Comments
 (0)