Skip to content

Commit 91bf074

Browse files
authored
Merge pull request #238 from uni19/fix_deconv
Fix deconvolution (transposed convolution) conversion when groups > 1.
2 parents 93173f8 + b696497 commit 91bf074

File tree

2 files changed

+103
-72
lines changed

2 files changed

+103
-72
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
1717
bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
1818
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1919
auto in = args[0].ITensor(); // assumes non-static input Tensor
20-
2120
auto w = Weights(ctx, args[1].unwrapToTensor());
2221
auto stride = util::toDims(args[3].unwrapToIntList());
2322
LOG_DEBUG("stride: " << stride);
@@ -29,36 +28,45 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
2928
auto out_padding = util::toDims(args[7].unwrapToIntList());
3029
LOG_DEBUG("out_padding: " << out_padding);
3130
int64_t groups = args[8].unwrapToInt();
31+
LOG_DEBUG("groups: " << groups);
3232

3333
nvinfer1::ILayer* new_layer;
3434
if (transposed) {
35-
nvinfer1::IDeconvolutionLayer* deconv;
35+
Weights bias;
3636
if (args[2].IValue()->isTensor()) {
37-
Weights b(ctx, args[2].IValue()->toTensor());
38-
deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, b.data);
37+
bias = Weights(ctx, args[2].unwrapToTensor());
3938
} else {
40-
deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps, w.kernel_shape, w.data, {});
39+
bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[1] * groups));
4140
}
4241

42+
// shape of deconvolution's weight: [in, out/groups, ...]
43+
auto deconv = ctx->net->addDeconvolutionNd(
44+
*in, args[1].unwrapToTensor().sizes()[1] * groups, w.kernel_shape, w.data, bias.data);
4345
TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
4446

4547
deconv->setStrideNd(stride);
4648
deconv->setPaddingNd(padding);
4749
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR == 1)
4850
deconv->setDilationNd(dilation);
4951
deconv->setNbGroups(groups);
52+
#else
53+
TRTORCH_CHECK(groups == 1, "for deconv with groups > 1, require TensorRT version >= 7.1");
54+
for (auto it = dilation.begin(); it != dilation.end(); ++it) {
55+
TRTORCH_CHECK(*it == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
56+
}
5057
#endif
5158
new_layer = deconv;
5259
} else {
53-
nvinfer1::IConvolutionLayer* conv;
60+
Weights bias;
5461
if (args[2].IValue()->isTensor()) {
55-
Weights b(ctx, args[2].unwrapToTensor());
56-
conv = ctx->net->addConvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, b.data);
62+
bias = Weights(ctx, args[2].unwrapToTensor());
5763
} else {
58-
Weights b(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[0]));
59-
conv = ctx->net->addConvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, b.data);
64+
bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[0]));
6065
}
6166

67+
// shape of convolution's weight: [out, in/groups, ...]
68+
auto conv =
69+
ctx->net->addConvolutionNd(*in, args[1].unwrapToTensor().sizes()[0], w.kernel_shape, w.data, bias.data);
6270
TRTORCH_CHECK(conv, "Unable to create convolution layer from node: " << *n);
6371

6472
conv->setStrideNd(stride);

tests/core/converters/test_conv_deconv.cpp

Lines changed: 85 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -532,65 +532,88 @@ TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
532532
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
533533
}
534534

535-
// TEST(Converters, ATenConvolutionWithDialationConvertsCorrectly) {
536-
// const auto graph = R"IR(
537-
// graph(%0 : Tensor,
538-
// %1 : Float(8, 3, 5, 5),
539-
// %2 : Float(8)):
540-
// %3 : int = prim::Constant[value=1]()
541-
// %4 : int = prim::Constant[value=0]()
542-
// %5 : int = prim::Constant[value=2]()
543-
// %6 : int = prim::Constant[value=0]()
544-
// %7 : bool = prim::Constant[value=0]()
545-
// %8 : int[] = prim::ListConstruct(%3, %3)
546-
// %9 : int[] = prim::ListConstruct(%4, %4)
547-
// %10 : int[] = prim::ListConstruct(%5, %5)
548-
// %11 : int[] = prim::ListConstruct(%6, %6)
549-
// %12 : int = prim::Constant[value=1]()
550-
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
551-
// %12, %7, %7, %7) return (%13))IR";
552-
553-
// conv_test_helper(graph);
554-
// }
555-
556-
// TEST(Converters, ATenConvolutionWithPostPaddingConvertsCorrectly) {
557-
// const auto graph = R"IR(
558-
// graph(%0 : Tensor,
559-
// %1 : Float(8, 3, 5, 5),
560-
// %2 : Float(8)):
561-
// %3 : int = prim::Constant[value=1]()
562-
// %4 : int = prim::Constant[value=0]()
563-
// %5 : int = prim::Constant[value=1]()
564-
// %6 : int = prim::Constant[value=2]()
565-
// %7 : bool = prim::Constant[value=0]()
566-
// %8 : int[] = prim::ListConstruct(%3, %3)
567-
// %9 : int[] = prim::ListConstruct(%4, %4)
568-
// %10 : int[] = prim::ListConstruct(%5, %5)
569-
// %11 : int[] = prim::ListConstruct(%6, %6)
570-
// %12 : int = prim::Constant[value=1]()
571-
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
572-
// %12, %7, %7, %7) return (%13))IR";
573-
574-
// conv_test_helper(graph);
575-
// }
576-
577-
// TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) {
578-
// const auto graph = R"IR(
579-
// graph(%0 : Tensor,
580-
// %1 : Float(8, 3, 5, 5),
581-
// %2 : Float(8)):
582-
// %3 : int = prim::Constant[value=1]()
583-
// %4 : int = prim::Constant[value=0]()
584-
// %5 : int = prim::Constant[value=1]()
585-
// %6 : int = prim::Constant[value=0]()
586-
// %7 : bool = prim::Constant[value=0]()
587-
// %8 : int[] = prim::ListConstruct(%3, %3)
588-
// %9 : int[] = prim::ListConstruct(%4, %4)
589-
// %10 : int[] = prim::ListConstruct(%5, %5)
590-
// %11 : int[] = prim::ListConstruct(%6, %6)
591-
// %12 : int = prim::Constant[value=2]()
592-
// %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11,
593-
// %12, %7, %7, %7) return (%13))IR";
594-
595-
// conv_test_helper(graph);
596-
// }
535+
TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) {
536+
const auto graph = R"IR(
537+
graph(%0 : Tensor,
538+
%1 : Float(8:48, 1:16, 2:4, 2:1),
539+
%2 : Float(8:1)):
540+
%3 : int = prim::Constant[value=1]()
541+
%4 : int = prim::Constant[value=2]()
542+
%5 : int = prim::Constant[value=1]()
543+
%6 : int = prim::Constant[value=0]()
544+
%7 : bool = prim::Constant[value=0]()
545+
%8 : int[] = prim::ListConstruct(%3, %3)
546+
%9 : int[] = prim::ListConstruct(%4, %4)
547+
%10 : int[] = prim::ListConstruct(%5, %5)
548+
%11 : int[] = prim::ListConstruct(%6, %6)
549+
%12 : int = prim::Constant[value=4]()
550+
%13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
551+
return (%13))IR";
552+
553+
auto g = std::make_shared<torch::jit::Graph>();
554+
torch::jit::parseIR(graph, &*g);
555+
556+
auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
557+
auto w = at::randint(1, 10, {8, 1, 2, 2}, {at::kCUDA});
558+
auto b = at::randint(1, 10, {8}, {at::kCUDA});
559+
560+
auto jit_in = at::clone(in);
561+
auto jit_w = at::clone(w);
562+
auto jit_b = at::clone(b);
563+
564+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
565+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
566+
567+
auto trt_in = at::clone(in);
568+
auto trt_w = at::clone(w);
569+
auto trt_b = at::clone(b);
570+
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
571+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
572+
573+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
574+
575+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
576+
}
577+
578+
// Verifies that a grouped transposed convolution (groups = 4) converts
// correctly: the TensorRT engine output must match the TorchScript output.
// Deconvolution weight shape is [in_channels, out_channels / groups, kH, kW]
// = [8, 4, 3, 3], so out_channels = 4 * groups = 16.
TEST(Converters, ATenConvTransposeWithGroupConvertsCorrectly) {
  // IR for aten::_convolution with stride=1, padding=1, dilation=1,
  // transposed=true, output_padding=0, groups=4.
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %1 : Float(8:56, 4:16, 3:3, 3:1),
          %2 : Float(16:1)):
      %3 : int = prim::Constant[value=1]()
      %4 : int = prim::Constant[value=1]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=0]()
      %7 : bool = prim::Constant[value=1]()
      %8 : int[] = prim::ListConstruct(%3, %3)
      %9 : int[] = prim::ListConstruct(%4, %4)
      %10 : int[] = prim::ListConstruct(%5, %5)
      %11 : int[] = prim::ListConstruct(%6, %6)
      %12 : int = prim::Constant[value=4]()
      %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7)
      return (%13))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  // Random integer inputs: 8 input channels split into 4 groups of 2.
  auto in = at::randint(1, 10, {1, 8, 5, 5}, {at::kCUDA});
  auto w = at::randint(1, 10, {8, 4, 3, 3}, {at::kCUDA});
  auto b = at::randint(1, 10, {16}, {at::kCUDA});

  // Clone so the JIT and TRT paths each get identical, independent tensors.
  auto jit_in = at::clone(in);
  auto jit_w = at::clone(w);
  auto jit_b = at::clone(b);

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_w = at::clone(w);
  auto trt_b = at::clone(b);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  // Reshape the flat TRT output to the JIT result's shape before comparing.
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

0 commit comments

Comments (0)