@@ -22,5 +22,26 @@ TEST(Converters, ATenUnsqueezeConvertsCorrectly) {
22
22
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
23
23
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
24
24
25
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
26
+ }
27
+
28
+ TEST (Converters, ATenUnsqueezeNegativeDimConvertsCorrectly) {
29
+ const auto graph = R"IR(
30
+ graph(%0 : Tensor):
31
+ %1 : int = prim::Constant[value=-4]()
32
+ %2 : Tensor = aten::unsqueeze(%0, %1)
33
+ return (%2))IR" ;
34
+
35
+ auto g = std::make_shared<torch::jit::Graph>();
36
+ torch::jit::parseIR (graph, &*g);
37
+
38
+ auto in = at::randint (1 , 10 , {2 , 3 , 3 }, {at::kCUDA });
39
+
40
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
41
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
42
+
43
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
44
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
45
+
25
46
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
26
47
}
0 commit comments