
Commit 4fc9f55

Author: chengduo
Merge pull request #5472 from chengduoZH/refine_im2col
Add dilations for conv2d and optimize conv2d code
2 parents: 09866fb + 00e0881

17 files changed: 944 additions, 635 deletions

paddle/operators/conv_cudnn_op.cc

Lines changed: 0 additions & 2 deletions
@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
   CudnnConvOpMaker(framework::OpProto* proto,
                    framework::OpAttrChecker* op_checker)
       : Conv2DOpMaker(proto, op_checker) {
-    AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
-        .SetDefault(std::vector<int>{1, 1});
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
                  "workspace is a section of GPU memory which will be "

paddle/operators/conv_op.cc

Lines changed: 40 additions & 14 deletions
@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   int groups = ctx->Attrs().Get<int>("groups");
+  std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
   int input_channels = in_dims[1];
   int output_channels = filter_dims[0];

@@ -52,9 +53,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
                  "The number of output channels should be divided by groups.");

   std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
-  for (size_t i = 0; i < paddings.size(); ++i) {
+  for (size_t i = 0; i < strides.size(); ++i) {
+    PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
+                           (dilations[i] * (filter_dims[i + 2] - 1) + 1) >
+                       0,
+                   "Due to the settings of paddings, filter_dims and "
+                   "dilations, the output size is less than 0, please check "
+                   "again.");
     output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
-                                      paddings[i], strides[i]));
+                                      dilations[i], paddings[i], strides[i]));
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }

@@ -78,9 +85,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator. "
             "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int> default:{1, 1}), the "
+                            "strides(h_stride, w_stride) of "
+                            "convolution operator.")
       .SetDefault({1, 1});
-  AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int> default:{0, 0}), the "
+                            "paddings(h_pad, w_pad) of "
+                            "convolution operator.")
       .SetDefault({0, 0});
   AddAttr<int>(
       "groups",

@@ -90,15 +103,20 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
       "first half of the input channels, while the second half of the filters "
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
+  AddAttr<std::vector<int>>("dilations",
+                            "(vector<int> default:{1, 1}), the "
+                            "dilations(h_dilation, w_dilation) of "
+                            "convolution operator.")
+      .SetDefault({1, 1});
   AddComment(R"DOC(
 Convolution Operator.

 The convolution operation calculates the output based on the input, filter
-and strides, paddings, groups parameters. The size of each dimension of the
+and strides, paddings, groups, dilations parameters. The size of each dimension of the
 parameters is checked in the infer-shape.
 Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
 size, C is the number of channels, H is the height of the feature, and W is
-the width of the feature. Parameters(ksize, strides, paddings) are two elements.
+the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements.
 These two elements represent height and width, respectively.
 The input(X) size and output(Out) size may be different.

@@ -109,8 +127,8 @@ The input(X) size and output(Out) size may be different.
   Output:
        Output shape: (N, C_out, H_out, W_out)
   where
-    H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1;
-    W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1;
+    H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1;
+    W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1;
 )DOC");
 }

@@ -135,13 +153,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator."
             "The format of output tensor is also NCDHW.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector, default:{0, 0, 0}), the strides of convolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int>, default:{1, 1, 1}), the "
+                            "strides(d_stride, h_stride, w_stride) of "
+                            "convolution operator.")
       .SetDefault({1, 1, 1});
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "(vector, default:{0, 0, 0}), the paddings of convolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int>, default:{0, 0, 0}), the "
+                            "paddings(d_pad, h_pad, w_pad) of convolution "
+                            "operator.")
       .SetDefault({0, 0, 0});
   AddAttr<int>(
       "groups",

@@ -151,6 +171,12 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
       "first half of the input channels, while the second half of the filters "
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
+  AddAttr<std::vector<int>>("dilations",
+                            "(vector<int> default:{1, 1, 1}), the "
+                            "dilations(d_dilation, h_dilation, w_dilation) of "
+                            "convolution operator. Currently, conv3d doesn't "
+                            "support dilation.")
+      .SetDefault({1, 1, 1});

   AddComment(R"DOC(
 Convolution3D Operator.
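Note: the updated H_out/W_out formulas reduce to the old ones when dilation is 1, because the effective kernel extent dilations[i] * (filter_size[i] - 1) + 1 then equals filter_size[i]. As a quick sanity check of the arithmetic, here is a minimal standalone program (illustrative only, not part of the commit) that mirrors the new helper:

#include <cstdio>

// Mirrors the updated OutputSize() from paddle/operators/conv_op.h.
inline int OutputSize(int input_size, int filter_size, int dilation,
                      int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;  // effective kernel extent
  return (input_size + 2 * padding - dkernel) / stride + 1;
}

int main() {
  // dilation = 1 reproduces the old formula: (32 + 2*1 - 3) / 1 + 1 = 32
  printf("%d\n", OutputSize(32, 3, 1, 1, 1));
  // dilation = 2 widens the kernel to 2*(3-1)+1 = 5: (32 + 2*1 - 5) / 1 + 1 = 30
  printf("%d\n", OutputSize(32, 3, 2, 1, 1));
  return 0;
}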

paddle/operators/conv_op.h

Lines changed: 79 additions & 46 deletions
@@ -27,11 +27,24 @@ using Tensor = framework::Tensor;

 // Base convolution operator definations for other conv
 // like operators to reuse the implementation.
-inline int OutputSize(int input_size, int filter_size, int padding,
-                      int stride) {
-  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+inline int OutputSize(int input_size, int filter_size, int dilation,
+                      int padding, int stride) {
+  const int dkernel = dilation * (filter_size - 1) + 1;
+  const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
   return output_size;
 }
+inline bool IsExpand(std::vector<int64_t>& filter_dim,
+                     std::vector<int>& strides, std::vector<int>& paddings,
+                     std::vector<int>& dilations) {
+  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
+  for (size_t j = 0; j < strides.size(); ++j) {
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    strides_1 = strides_1 && (strides[j] == 1);
+    padding_0 = padding_0 && (paddings[j] == 0);
+    dilation_1 = dilation_1 && (dilations[j] == 1);
+  }
+  return !(filter_1 && strides_1 && padding_0 && dilation_1);
+}

 // Define Op classes in .h file so that other conv
 // operator implementations can reuse the code.
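IsExpand() gates the whole im2col/vol2col path. A minimal sketch (illustrative only; it copies the helper above verbatim) showing that a 1x1 filter with unit strides, zero paddings, and unit dilations needs no expansion, while a 3x3 filter does:

#include <cstdio>
#include <vector>

// Verbatim copy of the predicate added above in conv_op.h.
inline bool IsExpand(std::vector<int64_t>& filter_dim,
                     std::vector<int>& strides, std::vector<int>& paddings,
                     std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}

int main() {
  std::vector<int> strides{1, 1}, paddings{0, 0}, dilations{1, 1};
  std::vector<int64_t> filter_1x1{1, 1};  // spatial filter dims only, for illustration
  std::vector<int64_t> filter_3x3{3, 3};
  // 1x1/stride-1/pad-0/dilation-1: im2col would only copy data, so it is skipped.
  printf("%d\n", IsExpand(filter_1x1, strides, paddings, dilations));  // 0
  // 3x3 filter: the expanded col buffer is genuinely needed.
  printf("%d\n", IsExpand(filter_3x3, strides, paddings, dilations));  // 1
  return 0;
}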
@@ -50,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 class ConvOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
   void InferShape(framework::InferShapeContext* ctx) const override;
 };

 class ConvOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
   void InferShape(framework::InferShapeContext* ctx) const override;
 };

@@ -73,9 +84,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");
     output->mutable_data<T>(context.GetPlace());

+    int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

     const int batch_size = static_cast<int>(input->dims()[0]);
@@ -106,14 +118,17 @@ class GemmConvKernel : public framework::OpKernel<T> {
     framework::DDim col_matrix_shape =
         framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);

+    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
     Tensor col_matrix;
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
+    if (is_expand) {
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      col_matrix.ShareDataWith(col);
+      col_matrix.Resize(col_matrix_shape);
+    }

     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
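Deferring col's allocation into the is_expand branch is part of the "optimize conv2d code" half of this commit: when IsExpand() returns false, no intermediate im2col buffer is allocated at all, and the loop below aliases the input slice into col_matrix via ShareDataWith instead.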
@@ -130,24 +145,30 @@ class GemmConvKernel : public framework::OpKernel<T> {
     int in_step = static_cast<int>(input->dims()[1]) / groups;
     int out_step = static_cast<int>(output->dims()[1]) / groups;

+    math::Vol2ColFunctor<Place, T> vol2col;
+    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+
     for (int i = 0; i < batch_size; i++) {
       Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
       Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
+
       for (int g = 0; g < groups; g++) {
         Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

-        if (filter_shape_vec.size() == 2) {
+        if (!is_expand) {
+          col.ShareDataWith(in_slice);
+          col_matrix.ShareDataWith(col);
+          col_matrix.Resize(col_matrix_shape);
+        } else if (filter_shape_vec.size() == 2) {
           // im2col
-          math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-          im2col(context.device_context(), in_slice, col, strides[0],
-                 strides[1], paddings[0], paddings[0], paddings[1],
-                 paddings[1]);
+          im2col(context.device_context(), in_slice, dilations, strides,
+                 std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                  paddings[1]},
+                 &col);
         } else if (filter_shape_vec.size() == 3) {
           // vol2col
-          math::Vol2ColFunctor<Place, T> vol2col;
-          vol2col(context.device_context(), in_slice, col, strides[0],
-                  strides[1], strides[2], paddings[0], paddings[1],
-                  paddings[2]);
+          vol2col(context.device_context(), in_slice, dilations, strides,
+                  paddings, &col);
         }

         // gemm
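Two changes in this loop: the im2col/vol2col functors are hoisted out of the batch loop instead of being constructed on every group iteration, and their call sites now take dilations, strides, and a four-element padding vector (built from the two symmetric paddings), with the output tensor passed last by pointer, matching the refined im2col/vol2col interfaces this PR introduces.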
@@ -178,9 +199,10 @@ class GemmConvGradKernel : public framework::OpKernel<T> {

     if (!input_grad && !filter_grad) return;

+    int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

     const int batch_size = static_cast<int>(input->dims()[0]);

@@ -230,21 +252,27 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     int in_step = static_cast<int>(input->dims()[1]) / groups;
     int out_step = static_cast<int>(output_grad->dims()[1]) / groups;

+    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
     Tensor col_matrix;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
+    if (is_expand) {
+      col.mutable_data<T>(col_shape, context.GetPlace());
+      col_matrix.ShareDataWith(col);
+      col_matrix.Resize(col_matrix_shape);
+    }

     math::SetConstant<Place, T> set_zero;

     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
       set_zero(context.device_context(), input_grad, static_cast<T>(0));

+      math::Col2VolFunctor<Place, T> col2vol;
+      math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
+
       for (int i = 0; i < batch_size; i++) {
         Tensor out_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -254,24 +282,26 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           Tensor out_grad_slice =
               out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
           Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul<Place, T>(context.device_context(), filter_slice, true,
-                                 out_grad_slice, false, T(1.0), &col_matrix,
-                                 T(0.0));
-          // col2im
+
           Tensor in_grad_slice =
               in_grad_batch.Slice(g * in_step, (g + 1) * in_step);

-          if (filter_shape_vec.size() == 2) {
-            math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-            col2im(context.device_context(), in_grad_slice, col, strides[0],
-                   strides[1], paddings[0], paddings[0], paddings[1],
-                   paddings[1]);
+          if (!is_expand) {
+            col_matrix.ShareDataWith(in_grad_slice);
+            col_matrix.Resize(col_matrix_shape);
+          }
+          math::matmul<Place, T>(context.device_context(), filter_slice, true,
+                                 out_grad_slice, false, T(1.0), &col_matrix,
+                                 T(0.0));

-          } else if (filter_shape_vec.size() == 3) {
-            math::Col2VolFunctor<Place, T> col2vol;
-            col2vol(context.device_context(), in_grad_slice, col, strides[0],
-                    strides[1], strides[2], paddings[0], paddings[1],
-                    paddings[2]);
+          if (is_expand && filter_shape_vec.size() == 2) {
+            col2im(context.device_context(), col, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &in_grad_slice);
+          } else if (is_expand && filter_shape_vec.size() == 3) {
+            col2vol(context.device_context(), col, dilations, strides, paddings,
+                    &in_grad_slice);
           }
         }
       }
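Note the reordering in the backward path: when is_expand is false, col_matrix is aliased onto in_grad_slice before the matmul, so the GEMM writes the input gradient in place and the col2im/col2vol scatter is skipped entirely (hence the is_expand guard on both branches).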
@@ -282,7 +312,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
       Tensor filter_grad_ = *filter_grad;
       filter_grad_.Resize(filter_matrix_shape);
       set_zero(context.device_context(), filter_grad, static_cast<T>(0));
-
+      math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+      math::Vol2ColFunctor<Place, T> vol2col;
       for (int i = 0; i < batch_size; i++) {
         Tensor out_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_matrix_shape);

@@ -293,16 +324,18 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
               out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
           Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

-          if (filter_shape_vec.size() == 2) {
-            math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-            im2col(context.device_context(), in_slice, col, strides[0],
-                   strides[1], paddings[0], paddings[0], paddings[1],
-                   paddings[1]);
+          if (!is_expand) {
+            col.ShareDataWith(in_slice);
+            col_matrix.ShareDataWith(col);
+            col_matrix.Resize(col_matrix_shape);
+          } else if (filter_shape_vec.size() == 2) {
+            im2col(context.device_context(), in_slice, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &col);
          } else if (filter_shape_vec.size() == 3) {
-            math::Vol2ColFunctor<Place, T> vol2col;
-            vol2col(context.device_context(), in_slice, col, strides[0],
-                    strides[1], strides[2], paddings[0], paddings[1],
-                    paddings[2]);
+            vol2col(context.device_context(), in_slice, dilations, strides,
+                    paddings, &col);
          }

           // gemm
