Skip to content

Commit d33ec82

Browse files
committed
test(aten::select.int): added test for aten::select.int
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent e3b0e53 commit d33ec82

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

tests/core/converters/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ converter_test(
5959
name = "test_interpolate"
6060
)
6161

62+
converter_test(
63+
name = "test_select"
64+
)
65+
6266
test_suite(
6367
name = "test_converters",
6468
tests = [
@@ -74,6 +78,7 @@ test_suite(
7478
":test_softmax",
7579
":test_unary",
7680
":test_interpolate",
81+
":test_select"
7782
]
7883
)
7984

tests/core/converters/test_select.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%2 : int = prim::Constant[value=0]()
11+
%3 : int = prim::Constant[value=3]()
12+
%4 : Tensor = aten::select(%0, %2, %2)
13+
%5 : Tensor = aten::select(%4, %2, %3)
14+
return (%5))IR";
15+
16+
auto g = std::make_shared<torch::jit::Graph>();
17+
18+
torch::jit::parseIR(graph, &*g);
19+
20+
auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
21+
22+
auto jit_in = at::clone(in);
23+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
24+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
25+
26+
auto trt_in = at::clone(in);
27+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
28+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
29+
30+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
31+
32+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
33+
}

0 commit comments

Comments
 (0)