@@ -203,6 +203,170 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
203
203
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
204
204
}
205
205
206
+ TEST (Converters, ATenConvolution3dConvertsCorrectly) {
207
+ const auto graph = R"IR(
208
+ graph(%0 : Tensor,
209
+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
210
+ %2 : Float(32:1)):
211
+ %sv : int = prim::Constant[value=1]()
212
+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
213
+ %pv : int = prim::Constant[value=0]()
214
+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
215
+ %transposed : bool = prim::Constant[value=0]()
216
+ %opv : int = prim::Constant[value=0]()
217
+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
218
+ %g : int = prim::Constant[value=1]()
219
+ %fb : bool = prim::Constant[value=0]()
220
+ %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
221
+ return (%out))IR" ;
222
+
223
+ auto g = std::make_shared<torch::jit::Graph>();
224
+ torch::jit::parseIR (graph, &*g);
225
+
226
+ auto in = at::randint (1 , 10 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
227
+ auto w = at::randint (1 , 10 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
228
+ auto b = at::randint (1 , 10 , {32 }, {at::kCUDA });
229
+
230
+ auto jit_in = at::clone (in);
231
+ auto jit_w = at::clone (w);
232
+ auto jit_b = at::clone (b);
233
+
234
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
235
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
236
+
237
+ auto trt_in = at::clone (in);
238
+ auto trt_w = at::clone (w);
239
+ auto trt_b = at::clone (b);
240
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
241
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
242
+
243
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
244
+
245
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
246
+ }
247
+
248
+ TEST (Converters, ATenConvolution3dNoBiasConvertsCorrectly) {
249
+ const auto graph = R"IR(
250
+ graph(%0 : Tensor,
251
+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1)):
252
+ %bias : None = prim::Constant()
253
+ %sv : int = prim::Constant[value=1]()
254
+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
255
+ %pv : int = prim::Constant[value=0]()
256
+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
257
+ %transposed : bool = prim::Constant[value=0]()
258
+ %opv : int = prim::Constant[value=0]()
259
+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
260
+ %g : int = prim::Constant[value=1]()
261
+ %fb : bool = prim::Constant[value=0]()
262
+ %out : Tensor = aten::_convolution(%0, %1, %bias, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
263
+ return (%out))IR" ;
264
+
265
+ auto g = std::make_shared<torch::jit::Graph>();
266
+ torch::jit::parseIR (graph, &*g);
267
+
268
+ auto in = at::randint (1 , 2 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
269
+ auto w = at::randint (1 , 2 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
270
+
271
+ auto jit_in = at::clone (in);
272
+ auto jit_w = at::clone (w);
273
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w});
274
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
275
+
276
+ auto trt_in = at::clone (in);
277
+ auto trt_w = at::clone (w);
278
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w});
279
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
280
+
281
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
282
+
283
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
284
+ }
285
+
286
+ TEST (Converters, ATenConvolution3dWithPaddingConvertsCorrectly) {
287
+ const auto graph = R"IR(
288
+ graph(%0 : Tensor,
289
+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
290
+ %2 : Float(32:1)):
291
+ %sv : int = prim::Constant[value=1]()
292
+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
293
+ %pv : int = prim::Constant[value=1]()
294
+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
295
+ %transposed : bool = prim::Constant[value=0]()
296
+ %opv : int = prim::Constant[value=0]()
297
+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
298
+ %g : int = prim::Constant[value=1]()
299
+ %fb : bool = prim::Constant[value=0]()
300
+ %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
301
+ return (%out))IR" ;
302
+
303
+ auto g = std::make_shared<torch::jit::Graph>();
304
+ torch::jit::parseIR (graph, &*g);
305
+
306
+ auto in = at::randint (1 , 10 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
307
+ auto w = at::randint (1 , 10 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
308
+ auto b = at::randint (1 , 10 , {32 }, {at::kCUDA });
309
+
310
+ auto jit_in = at::clone (in);
311
+ auto jit_w = at::clone (w);
312
+ auto jit_b = at::clone (b);
313
+
314
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
315
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
316
+
317
+ auto trt_in = at::clone (in);
318
+ auto trt_w = at::clone (w);
319
+ auto trt_b = at::clone (b);
320
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
321
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
322
+
323
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
324
+
325
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
326
+ }
327
+
328
+ TEST (Converters, ATenConvolution3dWithStrideDilationConvertsCorrectly) {
329
+ const auto graph = R"IR(
330
+ graph(%0 : Tensor,
331
+ %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
332
+ %2 : Float(32:1)):
333
+ %sv : int = prim::Constant[value=2]()
334
+ %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
335
+ %pv : int = prim::Constant[value=1]()
336
+ %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
337
+ %transposed : bool = prim::Constant[value=0]()
338
+ %opv : int = prim::Constant[value=0]()
339
+ %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
340
+ %g : int = prim::Constant[value=1]()
341
+ %fb : bool = prim::Constant[value=0]()
342
+ %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
343
+ return (%out))IR" ;
344
+
345
+ auto g = std::make_shared<torch::jit::Graph>();
346
+ torch::jit::parseIR (graph, &*g);
347
+
348
+ auto in = at::randint (1 , 10 , {1 , 3 , 5 , 5 , 5 }, {at::kCUDA });
349
+ auto w = at::randint (1 , 10 , {32 , 3 , 3 , 3 , 3 }, {at::kCUDA });
350
+ auto b = at::randint (1 , 10 , {32 }, {at::kCUDA });
351
+
352
+ auto jit_in = at::clone (in);
353
+ auto jit_w = at::clone (w);
354
+ auto jit_b = at::clone (b);
355
+
356
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
357
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
358
+
359
+ auto trt_in = at::clone (in);
360
+ auto trt_w = at::clone (w);
361
+ auto trt_b = at::clone (b);
362
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
363
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
364
+
365
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
366
+
367
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
368
+ }
369
+
206
370
TEST (Converters, ATenConvTransposeConvertsCorrectly) {
207
371
const auto graph = R"IR(
208
372
graph(%0 : Tensor,
0 commit comments