1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/ir/irparser.h"
4
+ #include " tests/util/util.h"
5
+ #include " core/compiler.h"
6
+
7
+ TEST (Converters, ATenSelectIntTwiceConvertsCorrectly) {
8
+ const auto graph = R"IR(
9
+ graph(%0 : Tensor):
10
+ %2 : int = prim::Constant[value=0]()
11
+ %3 : int = prim::Constant[value=3]()
12
+ %4 : Tensor = aten::select(%0, %2, %2)
13
+ %5 : Tensor = aten::select(%4, %2, %3)
14
+ return (%5))IR" ;
15
+
16
+ auto g = std::make_shared<torch::jit::Graph>();
17
+
18
+ torch::jit::parseIR (graph, &*g);
19
+
20
+ auto in = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
21
+
22
+ auto jit_in = at::clone (in);
23
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
24
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
25
+
26
+ auto trt_in = at::clone (in);
27
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
28
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
29
+
30
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
31
+
32
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
33
+ }
0 commit comments