@@ -27,21 +27,30 @@ TEST(Converters, ATenSqueezeConvertsCorrectly) {
27
27
28
28
TEST (Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
29
29
const auto graph = R"IR(
30
- graph(%0 : Tensor):
31
- %1 : int = prim::Constant[value=1]()
32
- %2 : Tensor = aten::squeeze(%0, %1)
33
- return (%2))IR" ;
30
+ graph(%0 : Tensor, %1 : Tensor):
31
+ %2 : int = prim::Constant[value=1]()
32
+ %2.1 : Tensor = aten::add(%0, %1, %2)
33
+ %3 : Tensor = aten::squeeze(%2.1, %2)
34
+ %4 : Tensor = aten::add(%3, %1, %2)
35
+ return (%4))IR" ;
34
36
35
37
auto g = std::make_shared<torch::jit::Graph>();
36
38
torch::jit::parseIR (graph, &*g);
37
39
38
40
auto in = at::randint (1 , 10 , {2 , 3 , 3 }, {at::kCUDA });
41
+ auto in_add = at::randint (1 , 10 , {2 , 3 , 3 }, {at::kCUDA });
42
+
43
+ auto jit_in = at::clone (in);
44
+ auto jit_in_add = at::clone (in_add);
39
45
40
46
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
41
- auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
47
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in, jit_in_add});
48
+
49
+ auto trt_in = at::clone (jit_in);
50
+ auto trt_in_add = at::clone (jit_in_add);
42
51
43
52
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
44
- auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in });
53
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in, trt_in_add });
45
54
46
55
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
47
56
}
0 commit comments