Skip to content

Commit b756063

Browse files
authored
Speed depthwise transposed conv2d. (#11740)
* Speed depthwise transposed conv2d.
1 parent 8630ba2 commit b756063

File tree

5 files changed

+135
-24
lines changed

5 files changed

+135
-24
lines changed

paddle/fluid/operators/conv_transpose_op.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
302302

303303
namespace ops = paddle::operators;
304304

305+
// conv2d_transpose
305306
REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
306307
ops::Conv2DTransposeOpMaker,
307308
paddle::framework::DefaultGradOpDescMaker<true>);
@@ -317,6 +318,7 @@ REGISTER_OP_CPU_KERNEL(
317318
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
318319
double>);
319320

321+
// conv3d_transpose
320322
REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
321323
ops::Conv3DTransposeOpMaker,
322324
paddle::framework::DefaultGradOpDescMaker<true>);
@@ -331,3 +333,19 @@ REGISTER_OP_CPU_KERNEL(
331333
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
332334
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
333335
double>);
336+
337+
// depthwise conv2d_transpose
338+
REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
339+
ops::Conv2DTransposeOpMaker,
340+
paddle::framework::DefaultGradOpDescMaker<true>);
341+
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
342+
343+
REGISTER_OP_CPU_KERNEL(
344+
depthwise_conv2d_transpose,
345+
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
346+
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
347+
REGISTER_OP_CPU_KERNEL(
348+
depthwise_conv2d_transpose_grad,
349+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
350+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
351+
double>);

paddle/fluid/operators/conv_transpose_op.cu.cc

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,28 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/conv_transpose_op.h"
1616

1717
namespace ops = paddle::operators;
18+
using CUDA = paddle::platform::CUDADeviceContext;
1819

19-
REGISTER_OP_CUDA_KERNEL(
20-
conv2d_transpose,
21-
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
22-
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
23-
REGISTER_OP_CUDA_KERNEL(
24-
conv2d_transpose_grad,
25-
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
26-
float>,
27-
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
28-
double>);
29-
30-
REGISTER_OP_CUDA_KERNEL(
31-
conv3d_transpose,
32-
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
33-
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
34-
REGISTER_OP_CUDA_KERNEL(
35-
conv3d_transpose_grad,
36-
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
37-
float>,
38-
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
39-
double>);
20+
// conv2d
21+
REGISTER_OP_CUDA_KERNEL(conv2d_transpose,
22+
ops::GemmConvTransposeKernel<CUDA, float>,
23+
ops::GemmConvTransposeKernel<CUDA, double>);
24+
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad,
25+
ops::GemmConvTransposeGradKernel<CUDA, float>,
26+
ops::GemmConvTransposeGradKernel<CUDA, double>);
27+
28+
// conv3d
29+
REGISTER_OP_CUDA_KERNEL(conv3d_transpose,
30+
ops::GemmConvTransposeKernel<CUDA, float>,
31+
ops::GemmConvTransposeKernel<CUDA, double>);
32+
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_grad,
33+
ops::GemmConvTransposeGradKernel<CUDA, float>,
34+
ops::GemmConvTransposeGradKernel<CUDA, double>);
35+
36+
// depthwise conv2d
37+
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose,
38+
ops::DepthwiseConvTransposeKernel<CUDA, float>,
39+
ops::DepthwiseConvTransposeKernel<CUDA, double>);
40+
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose_grad,
41+
ops::DepthwiseConvTransposeGradKernel<CUDA, float>,
42+
ops::DepthwiseConvTransposeGradKernel<CUDA, double>);

paddle/fluid/operators/conv_transpose_op.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/op_registry.h"
1919
#include "paddle/fluid/operators/math/blas.h"
20+
#include "paddle/fluid/operators/math/depthwise_conv.h"
2021
#include "paddle/fluid/operators/math/im2col.h"
2122
#include "paddle/fluid/operators/math/vol2col.h"
2223

@@ -316,5 +317,74 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
316317
}
317318
}
318319
};
320+
321+
template <typename DeviceContext, typename T>
322+
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
323+
public:
324+
void Compute(const framework::ExecutionContext& context) const override {
325+
const Tensor* input = context.Input<Tensor>("Input");
326+
Tensor filter = *context.Input<Tensor>("Filter");
327+
Tensor* output = context.Output<Tensor>("Output");
328+
output->mutable_data<T>(context.GetPlace());
329+
330+
int groups = context.Attr<int>("groups");
331+
PADDLE_ENFORCE_EQ(groups, filter.dims()[0]);
332+
333+
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
334+
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
335+
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
336+
for (auto v : dilations) {
337+
PADDLE_ENFORCE_EQ(v, 1);
338+
}
339+
340+
output->mutable_data<T>(context.GetPlace());
341+
auto& dev_ctx = context.template device_context<DeviceContext>();
342+
math::SetConstant<DeviceContext, T> set_zero;
343+
set_zero(dev_ctx, output, static_cast<T>(0));
344+
345+
math::DepthwiseConvInputGradFunctor<DeviceContext, T>
346+
depthwiseConvInputGrad;
347+
depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings,
348+
output);
349+
}
350+
};
351+
352+
template <typename DeviceContext, typename T>
353+
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
354+
public:
355+
void Compute(const framework::ExecutionContext& context) const override {
356+
const Tensor* input = context.Input<Tensor>("Input");
357+
const Tensor* output_grad =
358+
context.Input<Tensor>(framework::GradVarName("Output"));
359+
Tensor* input_grad =
360+
context.Output<Tensor>(framework::GradVarName("Input"));
361+
Tensor* filter_grad =
362+
context.Output<Tensor>(framework::GradVarName("Filter"));
363+
Tensor filter = *context.Input<Tensor>("Filter");
364+
365+
if (!input_grad && !filter_grad) return;
366+
367+
auto& dev_ctx = context.template device_context<DeviceContext>();
368+
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
369+
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
370+
371+
if (input_grad) {
372+
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
373+
depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings,
374+
input_grad);
375+
}
376+
377+
if (filter_grad) {
378+
math::SetConstant<DeviceContext, T> set_zero;
379+
filter_grad->mutable_data<T>(context.GetPlace());
380+
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
381+
382+
math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
383+
depthwiseConvFilterGrad;
384+
depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings,
385+
filter_grad);
386+
}
387+
}
388+
};
319389
} // namespace operators
320390
} // namespace paddle

python/paddle/fluid/layers/nn.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,10 +2334,17 @@ def conv2d_transpose(input,
23342334
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
23352335
conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3)
23362336
"""
2337-
helper = LayerHelper("conv2d_transpose", **locals())
2337+
2338+
input_channel = input.shape[1]
2339+
2340+
op_type = 'conv2d_transpose'
2341+
if (input_channel == groups and num_filters == input_channel and
2342+
not use_cudnn):
2343+
op_type = 'depthwise_conv2d_transpose'
2344+
2345+
helper = LayerHelper(op_type, **locals())
23382346
if not isinstance(input, Variable):
23392347
raise TypeError("Input of conv2d_transpose must be Variable")
2340-
input_channel = input.shape[1]
23412348

23422349
padding = utils.convert_to_list(padding, 2, 'padding')
23432350
stride = utils.convert_to_list(stride, 2, 'stride')
@@ -2371,7 +2378,7 @@ def conv2d_transpose(input,
23712378

23722379
pre_bias = helper.create_tmp_variable(dtype=input.dtype)
23732380
helper.append_op(
2374-
type='conv2d_transpose',
2381+
type=op_type,
23752382
inputs={'Input': [input],
23762383
'Filter': [img_filter]},
23772384
outputs={'Output': pre_bias},

python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,19 @@ def init_op_type(self):
242242
self.op_type = "conv2d_transpose"
243243

244244

245+
class TestDepthwiseConvTranspose(TestConv2dTransposeOp):
246+
def init_test_case(self):
247+
self.pad = [1, 1]
248+
self.stride = [2, 2]
249+
self.dilations = [1, 1]
250+
self.input_size = [2, 8, 16, 16] # NCHW
251+
self.groups = 8
252+
assert np.mod(self.input_size[1], self.groups) == 0
253+
f_c = self.input_size[1] / self.groups
254+
self.filter_size = [self.input_size[1], f_c, 4, 4]
255+
self.op_type = "depthwise_conv2d_transpose"
256+
257+
245258
# Please Don't remove the following code.
246259
# Currently, CI use cudnn V5.0 which not support dilation conv.
247260
# class TestCUDNNWithDilation(TestWithDilation):

0 commit comments

Comments
 (0)