1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/ir/irparser.h"
4
+ #include " tests/util/util.h"
5
+ #include " core/compiler.h"
6
+
7
+ TEST (Converters, ATenStackPureTensorConvertsCorrectly) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor,
10
+ %1 : Tensor):
11
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
12
+ %3 : int = prim::Constant[value=3]()
13
+ %4 : Tensor = aten::stack(%2, %3)
14
+ return (%4))IR" ;
15
+
16
+ auto g = std::make_shared<torch::jit::Graph>();
17
+ torch::jit::parseIR (graph, &*g);
18
+
19
+ auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
20
+ auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
21
+
22
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
23
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1, in2});
24
+
25
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
26
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1, in2});
27
+
28
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
29
+ }
30
+
31
+ TEST (Converters, ATenStackDiffTensorConvertsCorrectly) {
32
+ const auto graph = R"IR(
33
+ graph(%0 : Tensor,
34
+ %1 : Float(4, 4, 4)):
35
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
36
+ %3 : int = prim::Constant[value=1]()
37
+ %4 : Tensor = aten::stack(%2, %3)
38
+ return (%4))IR" ;
39
+
40
+ auto g = std::make_shared<torch::jit::Graph>();
41
+ torch::jit::parseIR (graph, &*g);
42
+
43
+ auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
44
+ auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
45
+
46
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
47
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
48
+
49
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
50
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
51
+
52
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
53
+ }
0 commit comments