@@ -31,6 +31,59 @@ TEST(Converters, ATenSelectIntConvertsCorrectly) {
31
31
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
32
32
}
33
33
34
+ TEST (Converters, ATenSelectIntDimIsOneConvertsCorrectly) {
35
+ const auto graph = R"IR(
36
+ graph(%0 : Tensor):
37
+ %2 : int = prim::Constant[value=1]()
38
+ %3 : int = prim::Constant[value=0]()
39
+ %4 : Tensor = aten::select(%0, %2, %3)
40
+ return (%4))IR" ;
41
+
42
+ auto g = std::make_shared<torch::jit::Graph>();
43
+
44
+ torch::jit::parseIR (graph, &*g);
45
+
46
+ auto in = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
47
+
48
+ auto jit_in = at::clone (in);
49
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
50
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
51
+
52
+ auto trt_in = at::clone (in);
53
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
54
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
55
+
56
+ // In order to check whether shape match that we don't do reshape.
57
+ // E.g. x = at::randint(1, 10, {4, 4, 4}, {at::kCUDA}), then aten::select(x, 1, 0). We should get a tensor y with
58
+ // shape {4, 4} instead of a tensor with shape {4, 1, 4}.
59
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
60
+ }
61
+
62
+ TEST (Converters, ATenSelectIntDimNegativeConvertsCorrectly) {
63
+ const auto graph = R"IR(
64
+ graph(%0 : Tensor):
65
+ %2 : int = prim::Constant[value=-2]()
66
+ %3 : int = prim::Constant[value=0]()
67
+ %4 : Tensor = aten::select(%0, %2, %3)
68
+ return (%4))IR" ;
69
+
70
+ auto g = std::make_shared<torch::jit::Graph>();
71
+
72
+ torch::jit::parseIR (graph, &*g);
73
+
74
+ auto in = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
75
+
76
+ auto jit_in = at::clone (in);
77
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
78
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
79
+
80
+ auto trt_in = at::clone (in);
81
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
82
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
83
+
84
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
85
+ }
86
+
34
87
TEST (Converters, ATenSelectIntTwiceConvertsCorrectly) {
35
88
const auto graph = R"IR(
36
89
graph(%0 : Tensor):
0 commit comments