Skip to content

Commit 535d1a5

Browse files
committed
chore: Add test case for torch max dim
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 0f48534 commit 535d1a5

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/core/conversion/converters/test_topk.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,28 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
3030
ASSERT_TRUE(
3131
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
3232
}
33+
34+
TEST(Converters, ATenMaxDimConvertsCorrectly) {
35+
const auto graph = R"IR(
36+
graph(%x.1 : Tensor):
37+
%2 : int = prim::Constant[value=0]()
38+
%3 : bool = prim::Constant[value=0]()
39+
%4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
40+
return (%4, %5))IR";
41+
42+
auto g = std::make_shared<torch::jit::Graph>();
43+
torch::jit::parseIR(graph, g.get());
44+
45+
auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});
46+
47+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
48+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
49+
50+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
51+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
52+
53+
ASSERT_TRUE(
54+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
55+
ASSERT_TRUE(
56+
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
57+
}

0 commit comments

Comments
 (0)