@@ -135,6 +135,36 @@ TEST(Converters, ATenBoolToINT32TensorConvertsCorrectly) {
135
135
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
136
136
}
137
137
138
+
139
+ TEST (Converters, ATenToSingleConvertsCorrectly) {
140
+ const auto graph = R"IR(
141
+ graph(%y.1 : Tensor):
142
+ %4 : int = prim::Constant[value=6]()
143
+ %5 : bool = prim::Constant[value=0]()
144
+ %6 : None = prim::Constant()
145
+ %y0.1 : Tensor = aten::to(%y.1, %4, %5, %5, %6)
146
+ return (%y0.1))IR" ;
147
+
148
+ auto g = std::make_shared<torch::jit::Graph>();
149
+
150
+ torch::jit::parseIR (graph, &*g);
151
+
152
+ auto in = at::randint (1 , 10 , {3 }, {at::kCUDA });
153
+
154
+ auto jit_in = at::clone (in);
155
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
156
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
157
+
158
+ auto trt_in = at::clone (in);
159
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
160
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
161
+
162
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
163
+ ASSERT_TRUE (jit_results[0 ].scalar_type () == trt.scalar_type ());
164
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
165
+ }
166
+
167
+
138
168
TEST (Converters, ATenTypeAsConvertsCorrectly) {
139
169
const auto graph = R"IR(
140
170
graph(%0 : Tensor,
0 commit comments