@@ -203,6 +203,170 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
203203 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
204204}
205205
206+ TEST (Converters, ATenConvolution3dConvertsCorrectly) {
207+ const auto graph = R"IR(
208+ graph(%0 : Tensor,
209+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
210+ %2 : Float(32:1)):
211+ %sv : int = prim::Constant[value=1]()
212+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
213+ %pv : int = prim::Constant[value=0]()
214+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
215+ %transposed : bool = prim::Constant[value=0]()
216+ %opv : int = prim::Constant[value=0]()
217+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
218+ %g : int = prim::Constant[value=1]()
219+ %fb : bool = prim::Constant[value=0]()
220+ %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
221+ return (%out))IR" ;
222+
223+ auto g = std::make_shared<torch::jit::Graph>();
224+ torch::jit::parseIR (graph, &*g);
225+
226+ auto in = at::randint (1 , 10 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
227+ auto w = at::randint (1 , 10 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
228+ auto b = at::randint (1 , 10 , {32 }, {at::kCUDA });
229+
230+ auto jit_in = at::clone (in);
231+ auto jit_w = at::clone (w);
232+ auto jit_b = at::clone (b);
233+
234+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
235+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
236+
237+ auto trt_in = at::clone (in);
238+ auto trt_w = at::clone (w);
239+ auto trt_b = at::clone (b);
240+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
241+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
242+
243+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
244+
245+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
246+ }
247+
248+ TEST (Converters, ATenConvolution3dNoBiasConvertsCorrectly) {
249+ const auto graph = R"IR(
250+ graph(%0 : Tensor,
251+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1)):
252+ %bias : None = prim::Constant()
253+ %sv : int = prim::Constant[value=1]()
254+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
255+ %pv : int = prim::Constant[value=0]()
256+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
257+ %transposed : bool = prim::Constant[value=0]()
258+ %opv : int = prim::Constant[value=0]()
259+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
260+ %g : int = prim::Constant[value=1]()
261+ %fb : bool = prim::Constant[value=0]()
262+ %out : Tensor = aten::_convolution(%0, %1, %bias, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
263+ return (%out))IR" ;
264+
265+ auto g = std::make_shared<torch::jit::Graph>();
266+ torch::jit::parseIR (graph, &*g);
267+
268+ auto in = at::randint (1 , 2 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
269+ auto w = at::randint (1 , 2 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
270+
271+ auto jit_in = at::clone (in);
272+ auto jit_w = at::clone (w);
273+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w});
274+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
275+
276+ auto trt_in = at::clone (in);
277+ auto trt_w = at::clone (w);
278+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w});
279+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
280+
281+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
282+
283+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
284+ }
285+
286+ TEST (Converters, ATenConvolution3dWithPaddingConvertsCorrectly) {
287+ const auto graph = R"IR(
288+ graph(%0 : Tensor,
289+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
290+ %2 : Float(32:1)):
291+ %sv : int = prim::Constant[value=1]()
292+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
293+ %pv : int = prim::Constant[value=1]()
294+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
295+ %transposed : bool = prim::Constant[value=0]()
296+ %opv : int = prim::Constant[value=0]()
297+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
298+ %g : int = prim::Constant[value=1]()
299+ %fb : bool = prim::Constant[value=0]()
300+ %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
301+ return (%out))IR" ;
302+
303+ auto g = std::make_shared<torch::jit::Graph>();
304+ torch::jit::parseIR (graph, &*g);
305+
306+ auto in = at::randint (1 , 10 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
307+ auto w = at::randint (1 , 10 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
308+ auto b = at::randint (1 , 10 , {32 }, {at::kCUDA });
309+
310+ auto jit_in = at::clone (in);
311+ auto jit_w = at::clone (w);
312+ auto jit_b = at::clone (b);
313+
314+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
315+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
316+
317+ auto trt_in = at::clone (in);
318+ auto trt_w = at::clone (w);
319+ auto trt_b = at::clone (b);
320+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
321+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
322+
323+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
324+
325+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
326+ }
327+
328+ TEST (Converters, ATenConvolution3dWithStrideDilationConvertsCorrectly) {
329+ const auto graph = R"IR(
330+ graph(%0 : Tensor,
331+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
332+ %2 : Float(32:1)):
333+ %sv : int = prim::Constant[value=2]()
334+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
335+ %pv : int = prim::Constant[value=1]()
336+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
337+ %transposed : bool = prim::Constant[value=0]()
338+ %opv : int = prim::Constant[value=0]()
339+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
340+ %g : int = prim::Constant[value=1]()
341+ %fb : bool = prim::Constant[value=0]()
342+ %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
343+ return (%out))IR" ;
344+
345+ auto g = std::make_shared<torch::jit::Graph>();
346+ torch::jit::parseIR (graph, &*g);
347+
348+ auto in = at::randint (1 , 10 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
349+ auto w = at::randint (1 , 10 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
350+ auto b = at::randint (1 , 10 , {32 }, {at::kCUDA });
351+
352+ auto jit_in = at::clone (in);
353+ auto jit_w = at::clone (w);
354+ auto jit_b = at::clone (b);
355+
356+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
357+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
358+
359+ auto trt_in = at::clone (in);
360+ auto trt_w = at::clone (w);
361+ auto trt_b = at::clone (b);
362+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
363+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
364+
365+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
366+
367+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
368+ }
369+
206370TEST (Converters, ATenConvTransposeConvertsCorrectly) {
207371 const auto graph = R"IR(
208372 graph(%0 : Tensor,
0 commit comments