5
5
#include " torch/csrc/jit/ir/irparser.h"
6
6
7
7
TEST (Converters, ATenStackPureTensorConvertsCorrectly) {
8
+ auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
9
+ auto g = std::make_shared<torch::jit::Graph>();
10
+ torch::jit::parseIR (graph, g.get ());
11
+
12
+ auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
13
+ auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
14
+
15
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
16
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
17
+
18
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
19
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
20
+
21
+ ASSERT_TRUE (
22
+ torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
23
+ };
8
24
const auto graph = R"IR(
9
25
graph(%0 : Tensor,
10
26
%1 : Tensor):
11
27
%2 : Tensor[] = prim::ListConstruct(%0, %1)
12
28
%3 : int = prim::Constant[value=3]()
13
29
%4 : Tensor = aten::stack(%2, %3)
14
30
return (%4))IR" ;
31
+ const auto graph2 = R"IR(
32
+ graph(%0 : Tensor,
33
+ %1 : Tensor):
34
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
35
+ %3 : int = prim::Constant[value=-1]()
36
+ %4 : Tensor = aten::stack(%2, %3)
37
+ return (%4))IR" ;
15
38
16
- auto g = std::make_shared<torch::jit::Graph>();
17
- torch::jit::parseIR (graph, g.get ());
39
+ TestATenStackPureTensorConvertsCorrectly (graph);
40
+ TestATenStackPureTensorConvertsCorrectly (graph2);
41
+ }
18
42
19
- auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
20
- auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
43
+ TEST (Converters, ATenStackDiffTensorConvertsCorrectly) {
44
+ auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
45
+ auto g = std::make_shared<torch::jit::Graph>();
46
+ torch::jit::parseIR (graph, g.get ());
21
47
22
- auto params = torch_tensorrt::core::ir::get_static_params (g-> inputs (), { });
23
- auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params , {in1, in2 });
48
+ auto in1 = at::randint ( 1 , 10 , { 4 , 4 , 4 }, {at:: kCUDA });
49
+ auto in2 = at::randint ( 1 , 10 , {4 , 4 , 4 }, {at:: kCUDA });
24
50
25
- params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
26
- auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2 });
51
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {in2 });
52
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1});
27
53
28
- ASSERT_TRUE (
29
- torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
30
- }
54
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {in2});
55
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1});
31
56
32
- TEST (Converters, ATenStackDiffTensorConvertsCorrectly) {
57
+ ASSERT_TRUE (
58
+ torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
59
+ };
33
60
const auto graph = R"IR(
34
61
graph(%0 : Tensor,
35
62
%1 : Float(4, 4, 4, strides=[16, 4, 1])):
36
63
%2 : Tensor[] = prim::ListConstruct(%0, %1)
37
64
%3 : int = prim::Constant[value=1]()
38
65
%4 : Tensor = aten::stack(%2, %3)
39
66
return (%4))IR" ;
40
-
41
- auto g = std::make_shared<torch::jit::Graph>();
42
- torch::jit::parseIR (graph, g.get ());
43
-
44
- auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
45
- auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
46
-
47
- auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {in2});
48
- auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1});
49
-
50
- params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {in2});
51
- auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1});
52
-
53
- ASSERT_TRUE (
54
- torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
55
- }
67
+ const auto graph2 = R"IR(
68
+ graph(%0 : Tensor,
69
+ %1 : Float(4, 4, 4, strides=[16, 4, 1])):
70
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
71
+ %3 : int = prim::Constant[value=-1]()
72
+ %4 : Tensor = aten::stack(%2, %3)
73
+ return (%4))IR" ;
74
+ TestATenStackDiffTensorConvertsCorrectly (graph);
75
+ TestATenStackDiffTensorConvertsCorrectly (graph2);
76
+ }
0 commit comments