@@ -532,65 +532,88 @@ TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
532
532
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
533
533
}
534
534
535
- // TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
536
- // const auto graph = R"IR(
537
- // graph(%0 : Tensor,
538
- // %1 : Float(8, 3, 5, 5),
539
- // %2 : Float(8)):
540
- // %3 : int = prim::Constant[value=1]()
541
- // %4 : int = prim::Constant[value=0]()
542
- // %5 : int = prim::Constant[value=2]()
543
- // %6 : int = prim::Constant[value=0]()
544
- // %7 : bool = prim::Constant[value=0]()
545
- // %8 : int[] = prim::ListConstruct(%3, %3)
546
- // %9 : int[] = prim::ListConstruct(%4, %4)
547
- // %10 : int[] = prim::ListConstruct(%5, %5)
548
- // %11 : int[] = prim::ListConstruct(%6, %6)
549
- // %12 : int = prim::Constant[value=1]()
550
- // %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
551
- // %12, %7, %7, %7) return (%13))IR";
552
-
553
- // conv_test_helper(graph);
554
- // }
555
-
556
- // TEST(Converters, ATenConvolutionWithPostPaddingConvertsCorrectly) {
557
- // const auto graph = R"IR(
558
- // graph(%0 : Tensor,
559
- // %1 : Float(8, 3, 5, 5),
560
- // %2 : Float(8)):
561
- // %3 : int = prim::Constant[value=1]()
562
- // %4 : int = prim::Constant[value=0]()
563
- // %5 : int = prim::Constant[value=1]()
564
- // %6 : int = prim::Constant[value=2]()
565
- // %7 : bool = prim::Constant[value=0]()
566
- // %8 : int[] = prim::ListConstruct(%3, %3)
567
- // %9 : int[] = prim::ListConstruct(%4, %4)
568
- // %10 : int[] = prim::ListConstruct(%5, %5)
569
- // %11 : int[] = prim::ListConstruct(%6, %6)
570
- // %12 : int = prim::Constant[value=1]()
571
- // %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
572
- // %12, %7, %7, %7) return (%13))IR";
573
-
574
- // conv_test_helper(graph);
575
- // }
576
-
577
- // TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) {
578
- // const auto graph = R"IR(
579
- // graph(%0 : Tensor,
580
- // %1 : Float(8, 3, 5, 5),
581
- // %2 : Float(8)):
582
- // %3 : int = prim::Constant[value=1]()
583
- // %4 : int = prim::Constant[value=0]()
584
- // %5 : int = prim::Constant[value=1]()
585
- // %6 : int = prim::Constant[value=0]()
586
- // %7 : bool = prim::Constant[value=0]()
587
- // %8 : int[] = prim::ListConstruct(%3, %3)
588
- // %9 : int[] = prim::ListConstruct(%4, %4)
589
- // %10 : int[] = prim::ListConstruct(%5, %5)
590
- // %11 : int[] = prim::ListConstruct(%6, %6)
591
- // %12 : int = prim::Constant[value=2]()
592
- // %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
593
- // %12, %7, %7, %7) return (%13))IR";
594
-
595
- // conv_test_helper(graph);
596
- // }
535
+ TEST (Converters, ATenConvolutionWithGroupConvertsCorrectly) {
536
+ const auto graph = R"IR(
537
+ graph(%0 : Tensor,
538
+ %1 : Float(8:48, 1:16, 2:4, 2:1),
539
+ %2 : Float(8:1)):
540
+ %3 : int = prim::Constant[value=1]()
541
+ %4 : int = prim::Constant[value=2]()
542
+ %5 : int = prim::Constant[value=1]()
543
+ %6 : int = prim::Constant[value=0]()
544
+ %7 : bool = prim::Constant[value=0]()
545
+ %8 : int[] = prim::ListConstruct(%3, %3)
546
+ %9 : int[] = prim::ListConstruct(%4, %4)
547
+ %10 : int[] = prim::ListConstruct(%5, %5)
548
+ %11 : int[] = prim::ListConstruct(%6, %6)
549
+ %12 : int = prim::Constant[value=4]()
550
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
551
+ return (%13))IR" ;
552
+
553
+ auto g = std::make_shared<torch::jit::Graph>();
554
+ torch::jit::parseIR (graph, &*g);
555
+
556
+ auto in = at::randint (1 , 10 , {1 , 4 , 4 , 4 }, {at::kCUDA });
557
+ auto w = at::randint (1 , 10 , {8 , 1 , 2 , 2 }, {at::kCUDA });
558
+ auto b = at::randint (1 , 10 , {8 }, {at::kCUDA });
559
+
560
+ auto jit_in = at::clone (in);
561
+ auto jit_w = at::clone (w);
562
+ auto jit_b = at::clone (b);
563
+
564
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
565
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
566
+
567
+ auto trt_in = at::clone (in);
568
+ auto trt_w = at::clone (w);
569
+ auto trt_b = at::clone (b);
570
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
571
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
572
+
573
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
574
+
575
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
576
+ }
577
+
578
+ TEST (Converters, ATenConvTransposeWithGroupConvertsCorrectly) {
579
+ const auto graph = R"IR(
580
+ graph(%0 : Tensor,
581
+ %1 : Float(8:56, 4:16, 3:3, 3:1),
582
+ %2 : Float(16:1)):
583
+ %3 : int = prim::Constant[value=1]()
584
+ %4 : int = prim::Constant[value=1]()
585
+ %5 : int = prim::Constant[value=1]()
586
+ %6 : int = prim::Constant[value=0]()
587
+ %7 : bool = prim::Constant[value=1]()
588
+ %8 : int[] = prim::ListConstruct(%3, %3)
589
+ %9 : int[] = prim::ListConstruct(%4, %4)
590
+ %10 : int[] = prim::ListConstruct(%5, %5)
591
+ %11 : int[] = prim::ListConstruct(%6, %6)
592
+ %12 : int = prim::Constant[value=4]()
593
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
594
+ return (%13))IR" ;
595
+
596
+ auto g = std::make_shared<torch::jit::Graph>();
597
+ torch::jit::parseIR (graph, &*g);
598
+
599
+ auto in = at::randint (1 , 10 , {1 , 8 , 5 , 5 }, {at::kCUDA });
600
+ auto w = at::randint (1 , 10 , {8 , 4 , 3 , 3 }, {at::kCUDA });
601
+ auto b = at::randint (1 , 10 , {16 }, {at::kCUDA });
602
+
603
+ auto jit_in = at::clone (in);
604
+ auto jit_w = at::clone (w);
605
+ auto jit_b = at::clone (b);
606
+
607
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {jit_w, jit_b});
608
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
609
+
610
+ auto trt_in = at::clone (in);
611
+ auto trt_w = at::clone (w);
612
+ auto trt_b = at::clone (b);
613
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_w, trt_b});
614
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
615
+
616
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
617
+
618
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
619
+ }
0 commit comments