@@ -30,3 +30,28 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
30
30
ASSERT_TRUE (
31
31
torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ].reshape_as (jit_results[1 ]), 2e-6 ));
32
32
}
33
+
34
+ TEST (Converters, ATenMaxDimConvertsCorrectly) {
35
+ const auto graph = R"IR(
36
+ graph(%x.1 : Tensor):
37
+ %2 : int = prim::Constant[value=0]()
38
+ %3 : bool = prim::Constant[value=0]()
39
+ %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
40
+ return (%4, %5))IR" ;
41
+
42
+ auto g = std::make_shared<torch::jit::Graph>();
43
+ torch::jit::parseIR (graph, g.get ());
44
+
45
+ auto in = at::rand ({2 , 3 , 5 , 5 }, {at::kCUDA });
46
+
47
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
48
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
49
+
50
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
51
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
52
+
53
+ ASSERT_TRUE (
54
+ torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
55
+ ASSERT_TRUE (
56
+ torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ].reshape_as (jit_results[1 ]), 2e-6 ));
57
+ }
0 commit comments