
Commit 7b212bf

Merge pull request #1234 from ruoqianguo/deconv_out_padding
Add outputPadding in deconv
2 parents 293db8b + dca6a9f commit 7b212bf

File tree (2 files changed: +199, -3 lines)

    core/conversion/converters/impl/conv_deconv.cpp
    tests/core/conversion/converters/test_conv_deconv.cpp
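For context (not part of this diff): PyTorch's ConvTranspose layers compute each spatial output size as

    L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

so output_padding only extends the trailing edge of the output. TensorRT's IDeconvolutionLayer exposes pre- and post-padding but has no output-padding parameter, which is why the converter change below rebalances the padding and, when that is not enough, pads the deconvolution output explicitly.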

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 74 additions & 3 deletions
@@ -132,12 +132,34 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
 
   nvinfer1::ILayer* new_layer;
   if (transposed) {
+    // Refer to
+    // https://github.com/onnx/onnx-tensorrt/blob/c3cfcbc8248c6bd007e6630af2085df5e4834b42/builtin_op_importers.cpp#L734
+    nvinfer1::Dims begPadding = padding;
+    bool hasOutputPadding = false;
+    int nbSpatialDims = out_padding.nbDims;
+    // When there is out_padding: if padding is larger than out_padding, just fold out_padding into padding;
+    // otherwise reduce out_padding as much as possible.
+    for (int i = 0; i < nbSpatialDims; ++i) {
+      if (padding.d[i] - out_padding.d[i] >= 0) {
+        padding.d[i] -= out_padding.d[i];
+        out_padding.d[i] = 0;
+      } else {
+        // Reduce out_padding as much as possible.
+        out_padding.d[i] -= padding.d[i];
+        padding.d[i] = 0;
+        hasOutputPadding = true;
+      }
+    }
+
     // shape of deconvolution's weight: [in, out/groups, ...]
-    auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data);
+    // If there is still output padding, remove the bias. Bias will be added below.
+    auto deconv = ctx->net->addDeconvolutionNd(
+        *in, w.shape.d[1] * groups, w.kernel_shape, w.data, hasOutputPadding ? nvinfer1::Weights{} : bias.data);
     TORCHTRT_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
 
     deconv->setStrideNd(stride);
-    deconv->setPaddingNd(padding);
+    deconv->setPrePadding(begPadding);
+    deconv->setPostPadding(padding);
 #if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1)
     deconv->setDilationNd(dilation);
     deconv->setNbGroups(groups);
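A minimal standalone sketch of the rebalance loop above, using hypothetical single-dimension values (padding = 2, output_padding = 3) rather than anything from this PR:

#include <cstdio>

// Rebalance one spatial dimension the same way the loop above rebalances
// nvinfer1::Dims: keep the requested padding as pre-padding, absorb as much
// of output_padding into post-padding as possible, and track any residual.
int main() {
  int padding = 2;          // hypothetical conv_transpose padding
  int out_padding = 3;      // hypothetical output_padding
  int begPadding = padding; // pre-padding keeps the original value
  bool hasOutputPadding = false;
  if (padding - out_padding >= 0) {
    padding -= out_padding; // fold output_padding into post-padding
    out_padding = 0;
  } else {
    out_padding -= padding; // keep only the residual post-padding cannot absorb
    padding = 0;
    hasOutputPadding = true;
  }
  // Prints: pre=2 post=0 residual=1 needs_fill=1
  std::printf("pre=%d post=%d residual=%d needs_fill=%d\n",
              begPadding, padding, out_padding, (int)hasOutputPadding);
  return 0;
}

With these values the deconvolution gets pre-padding 2 and post-padding 0, and the remaining single element of output padding is applied by the fill-slice added in the next hunk.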
@@ -147,7 +169,56 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
       TORCHTRT_CHECK(dilation.d[idx] == 1, "for deconv with dilation > 1, require TensorRT version >= 7.1");
     }
 #endif
-    new_layer = deconv;
+    if (hasOutputPadding) {
+      LOG_DEBUG("Padding output deconvolution tensor with:" << out_padding);
+
+      // Add padding layer
+      nvinfer1::ITensor* start;
+      nvinfer1::ITensor* totalPadding;
+      auto in_nbDims = orig_dims.nbDims;
+      std::vector<int32_t> startVec(in_nbDims, 0);
+      std::vector<int32_t> totalPaddingVec(in_nbDims, 0);
+      int32_t diff = in_nbDims - out_padding.nbDims;
+      for (int32_t i = diff; i < in_nbDims; i++) {
+        int32_t idx = i - diff;
+        startVec[i] = 0; // Don't need begin padding, only post padding
+        totalPaddingVec[i] = out_padding.d[idx];
+      }
+      start = tensor_to_const(ctx, torch::tensor(startVec, torch::kInt32));
+      totalPadding = tensor_to_const(ctx, torch::tensor(totalPaddingVec, torch::kInt32));
+
+      nvinfer1::ITensor* tensorPtr = deconv->getOutput(0);
+      nvinfer1::ITensor* deconvOutShape = ctx->net->addShape(*tensorPtr)->getOutput(0);
+      const auto size =
+          ctx->net->addElementWise(*deconvOutShape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
+
+      nvinfer1::Dims stride;
+      stride.nbDims = in_nbDims;
+      for (size_t i = 0; i < in_nbDims; i++) {
+        stride.d[i] = 1;
+      }
+      const auto& dummy = stride;
+      auto* sliceLayer = ctx->net->addSlice(*tensorPtr, dummy, dummy, stride);
+      sliceLayer->setInput(1, *start);
+      sliceLayer->setInput(2, *size);
+      sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
+      tensorPtr = sliceLayer->getOutput(0);
+
+      nvinfer1::Dims constantDims;
+      constantDims.nbDims = in_nbDims;
+      for (size_t i = 0; i < in_nbDims; i++) {
+        constantDims.d[i] = 1;
+      }
+      constantDims.d[diff - 1] =
+          bias.shape.d[0]; // Set C dimension to bias dim and other dimensions to 1 to enable broadcast
+      auto const_layer = ctx->net->addConstant(constantDims, bias.data);
+      auto add_bias_layer =
+          ctx->net->addElementWise(*tensorPtr, *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUM);
+
+      new_layer = add_bias_layer;
+    } else {
+      new_layer = deconv;
+    }
   } else {
     // shape of convolution's weight: [out, in/groups, ...]
     auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
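A note on why the bias is pulled out when hasOutputPadding is true: PyTorch adds the per-channel bias b_c to every output element, including the positions created by output_padding. If the bias stayed inside the deconvolution, the kFILL slice would leave those extra positions at the fill value (0) instead of b_c. Building the deconvolution without bias, zero-filling the extra positions, and then broadcast-adding the (1, C, 1, ...) bias constant gives deconv + b_c in the interior and 0 + b_c = b_c in the padded region, matching PyTorch. The slice size is also computed in-network from the IShapeLayer output plus out_padding, so the padded subgraph does not rely on a statically known deconvolution output shape.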

tests/core/conversion/converters/test_conv_deconv.cpp

Lines changed: 125 additions & 0 deletions
@@ -570,6 +570,131 @@ TEST(Converters, ATenConvTransposeWithPaddingConvertsCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenConv1dTransposeWithPaddingOutPaddingConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1])):
+      %2 : None = prim::Constant()
+      %3 : int = prim::Constant[value=2]()
+      %4 : int = prim::Constant[value=1]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3)
+      %9 : int[] = prim::ListConstruct(%4)
+      %10 : int[] = prim::ListConstruct(%5)
+      %11 : int[] = prim::ListConstruct(%6)
+      %12 : int = prim::Constant[value=1]()
+      %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7)
+      return (%13))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {3, 4, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenConvTransposeWithPaddingOutPaddingConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 4, 4, strides=[48, 16, 4, 1]),
+          %2 : Float(4)):
+      %3 : int = prim::Constant[value=2]()
+      %4 : int = prim::Constant[value=2]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : int[] = prim::ListConstruct(%4, %4)
+      %10 : int[] = prim::ListConstruct(%5, %5)
+      %11 : int[] = prim::ListConstruct(%6, %6)
+      %12 : int = prim::Constant[value=1]()
+      %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7)
+      return (%13))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
+  auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0];
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenConvTransposeOutPaddingBiggerThanPaddingConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 4, 4, strides=[48, 16, 4, 1]),
+          %2 : Float(4)):
+      %3 : int = prim::Constant[value=4]()
+      %4 : int = prim::Constant[value=2]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=3]()
+      %7 : bool = prim::Constant[value=1]()
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : int[] = prim::ListConstruct(%4, %4)
+      %10 : int[] = prim::ListConstruct(%5, %5)
+      %11 : int[] = prim::ListConstruct(%6, %6)
+      %12 : int = prim::Constant[value=1]()
+      %13 : Tensor = aten::_convolution(%0, %1, %2, %8, %9, %10, %7, %11, %12, %7, %7, %7, %7)
+      return (%13))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
+  auto w = at::randint(1, 10, {4, 3, 2, 2}, {at::kCUDA});
+  auto b = at::randint(1, 10, {3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenConvolutionWithGroupConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor,
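As a sanity check on the last test (reading the aten::_convolution arguments as stride = 4, padding = 2, dilation = 1, output_padding = 3, with a 2x2 kernel, 1x4x4x4 input, and 4x3x2x2 weight), each spatial output size is

    (4 - 1) * 4 - 2 * 2 + 1 * (2 - 1) + 3 + 1 = 13

so both the JIT and TensorRT paths should produce a 1x3x13x13 tensor. Because padding (2) is smaller than output_padding (3), this is the case that exercises the kFILL slice with a residual of 1; the other two tests are handled by the pre/post padding rebalance alone.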
