@@ -22,5 +22,35 @@ TEST(Converters, ATenSqueezeConvertsCorrectly) {
22
22
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
23
23
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
24
24
25
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
26
+ }
27
+
28
+ TEST (Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
29
+ const auto graph = R"IR(
30
+ graph(%0 : Tensor, %1 : Tensor):
31
+ %2 : int = prim::Constant[value=1]()
32
+ %2.1 : Tensor = aten::add(%0, %1, %2)
33
+ %3 : Tensor = aten::squeeze(%2.1, %2)
34
+ %4 : Tensor = aten::add(%3, %1, %2)
35
+ return (%4))IR" ;
36
+
37
+ auto g = std::make_shared<torch::jit::Graph>();
38
+ torch::jit::parseIR (graph, &*g);
39
+
40
+ auto in = at::randint (1 , 10 , {2 , 3 , 3 }, {at::kCUDA });
41
+ auto in_add = at::randint (1 , 10 , {2 , 3 , 3 }, {at::kCUDA });
42
+
43
+ auto jit_in = at::clone (in);
44
+ auto jit_in_add = at::clone (in_add);
45
+
46
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
47
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in, jit_in_add});
48
+
49
+ auto trt_in = at::clone (jit_in);
50
+ auto trt_in_add = at::clone (jit_in_add);
51
+
52
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
53
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in, trt_in_add});
54
+
25
55
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
26
56
}
0 commit comments