@@ -570,6 +570,131 @@ TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
570
570
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
571
571
}
572
572
573
+ TEST (Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) {
574
+ const auto graph = R"IR(
575
+ graph(%0 : Tensor,
576
+ %1 : Float(4, 3, 3, strides=[9, 3, 1])):
577
+ %2 : None = prim::Constant()
578
+ %3 : int = prim::Constant[value=2]()
579
+ %4 : int = prim::Constant[value=1]()
580
+ %5 : int = prim::Constant[value=1]()
581
+ %6 : int = prim::Constant[value=1]()
582
+ %7 : bool = prim::Constant[value=1]()
583
+ %8 : int[] = prim::ListConstruct(%3)
584
+ %9 : int[] = prim::ListConstruct(%4)
585
+ %10 : int[] = prim::ListConstruct(%5)
586
+ %11 : int[] = prim::ListConstruct(%6)
587
+ %12 : int = prim::Constant[value=1]()
588
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7)
589
+ return (%13))IR" ;
590
+
591
+ auto g = std::make_shared<torch::jit::Graph>();
592
+ torch::jit::parseIR (graph, g.get ());
593
+
594
+ auto in = at::randint (1 , 2 , {1 , 3 , 3 }, {at::kCUDA });
595
+ auto w = at::randint (1 , 2 , {3 , 4 , 3 }, {at::kCUDA });
596
+
597
+ auto jit_in = at::clone (in);
598
+ auto jit_w = at::clone (w);
599
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_w});
600
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
601
+
602
+ auto trt_in = at::clone (in);
603
+ auto trt_w = at::clone (w);
604
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_w});
605
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
606
+
607
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
608
+
609
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
610
+ }
611
+
612
+ TEST (Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) {
613
+ const auto graph = R"IR(
614
+ graph(%0 : Tensor,
615
+ %1 : Float(4, 3, 4, 4, strides=[48, 16, 4, 1]),
616
+ %2 : Float(4)):
617
+ %3 : int = prim::Constant[value=2]()
618
+ %4 : int = prim::Constant[value=2]()
619
+ %5 : int = prim::Constant[value=1]()
620
+ %6 : int = prim::Constant[value=1]()
621
+ %7 : bool = prim::Constant[value=1]()
622
+ %8 : int[] = prim::ListConstruct(%3, %3)
623
+ %9 : int[] = prim::ListConstruct(%4, %4)
624
+ %10 : int[] = prim::ListConstruct(%5, %5)
625
+ %11 : int[] = prim::ListConstruct(%6, %6)
626
+ %12 : int = prim::Constant[value=1]()
627
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7)
628
+ return (%13))IR" ;
629
+
630
+ auto g = std::make_shared<torch::jit::Graph>();
631
+ torch::jit::parseIR (graph, g.get ());
632
+
633
+ auto in = at::randint (1 , 10 , {1 , 4 , 4 , 4 }, {at::kCUDA });
634
+ auto w = at::randint (1 , 10 , {4 , 3 , 2 , 2 }, {at::kCUDA });
635
+ auto b = at::randint (1 , 10 , {3 }, {at::kCUDA });
636
+
637
+ auto jit_in = at::clone (in);
638
+ auto jit_w = at::clone (w);
639
+ auto jit_b = at::clone (b);
640
+
641
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_w, jit_b});
642
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
643
+
644
+ auto trt_in = at::clone (in);
645
+ auto trt_w = at::clone (w);
646
+ auto trt_b = at::clone (b);
647
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_w, trt_b});
648
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
649
+
650
+ auto trt = trt_results[0 ];
651
+
652
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
653
+ }
654
+
655
+ TEST (Converters, ATenConvTransposeOutPaddingBiggerThanPaddingConvertsCorrectly) {
656
+ const auto graph = R"IR(
657
+ graph(%0 : Tensor,
658
+ %1 : Float(4, 3, 4, 4, strides=[48, 16, 4, 1]),
659
+ %2 : Float(4)):
660
+ %3 : int = prim::Constant[value=4]()
661
+ %4 : int = prim::Constant[value=2]()
662
+ %5 : int = prim::Constant[value=1]()
663
+ %6 : int = prim::Constant[value=3]()
664
+ %7 : bool = prim::Constant[value=1]()
665
+ %8 : int[] = prim::ListConstruct(%3, %3)
666
+ %9 : int[] = prim::ListConstruct(%4, %4)
667
+ %10 : int[] = prim::ListConstruct(%5, %5)
668
+ %11 : int[] = prim::ListConstruct(%6, %6)
669
+ %12 : int = prim::Constant[value=1]()
670
+ %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7)
671
+ return (%13))IR" ;
672
+
673
+ auto g = std::make_shared<torch::jit::Graph>();
674
+ torch::jit::parseIR (graph, g.get ());
675
+
676
+ auto in = at::randint (1 , 10 , {1 , 4 , 4 , 4 }, {at::kCUDA });
677
+ auto w = at::randint (1 , 10 , {4 , 3 , 2 , 2 }, {at::kCUDA });
678
+ auto b = at::randint (1 , 10 , {3 }, {at::kCUDA });
679
+
680
+ auto jit_in = at::clone (in);
681
+ auto jit_w = at::clone (w);
682
+ auto jit_b = at::clone (b);
683
+
684
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_w, jit_b});
685
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
686
+
687
+ auto trt_in = at::clone (in);
688
+ auto trt_w = at::clone (w);
689
+ auto trt_b = at::clone (b);
690
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_w, trt_b});
691
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
692
+
693
+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
694
+
695
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
696
+ }
697
+
573
698
TEST (Converters, ATenConvolutionWithGroupConvertsCorrectly) {
574
699
const auto graph = R"IR(
575
700
graph(%0 : Tensor,
0 commit comments