Skip to content

Commit 23bf6b2

Browse files
authored
Merge pull request #4887 from chengduoZH/fix_im2col_kocf_for_sequence
Add up, down, left and right padding for im2col.
2 parents 8fdc315 + 09662da commit 23bf6b2

File tree

5 files changed

+225
-70
lines changed

5 files changed

+225
-70
lines changed

paddle/operators/conv2d_op.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
114114
// im2col
115115
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
116116
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
117-
paddings[0], paddings[1]);
117+
paddings[0], paddings[0], paddings[1], paddings[1]);
118118

119119
// gemm
120120
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
@@ -213,7 +213,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
213213
Tensor in_grad_slice =
214214
in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
215215
col2im(context.device_context(), in_grad_slice, col, strides[0],
216-
strides[1], paddings[0], paddings[1]);
216+
strides[1], paddings[0], paddings[0], paddings[1],
217+
paddings[1]);
217218
}
218219
}
219220
}
@@ -235,7 +236,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
235236
out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
236237
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
237238
im2col(context.device_context(), in_slice, col, strides[0],
238-
strides[1], paddings[0], paddings[1]);
239+
strides[1], paddings[0], paddings[0], paddings[1],
240+
paddings[1]);
239241

240242
// gemm
241243
Tensor filter_grad_slice =

paddle/operators/math/im2col.cc

Lines changed: 87 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
2929
public:
3030
void operator()(const platform::DeviceContext& context,
3131
const framework::Tensor& im, framework::Tensor& col,
32-
int stride_height, int stride_width, int padding_height,
33-
int padding_width) {
32+
int stride_height, int stride_width, int padding_up,
33+
int padding_down, int padding_left, int padding_right) {
3434
PADDLE_ENFORCE(im.dims().size() == 3);
3535
PADDLE_ENFORCE(col.dims().size() == 5);
3636

@@ -41,6 +41,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
4141
int filter_width = col.dims()[2];
4242
int output_height = col.dims()[3];
4343
int output_width = col.dims()[4];
44+
45+
PADDLE_ENFORCE_EQ(
46+
(input_height + padding_up + padding_down - filter_height) /
47+
stride_height +
48+
1,
49+
output_height,
50+
"Output_height and padding(padding_up, padding_down) are "
51+
"inconsistent.");
52+
PADDLE_ENFORCE_EQ(
53+
(input_width + padding_left + padding_right - filter_width) /
54+
stride_width +
55+
1,
56+
output_width,
57+
"output_width and padding(padding_left, padding_right) are "
58+
"inconsistent.");
59+
4460
int channels_col = input_channels * filter_height * filter_width;
4561

4662
const T* im_data = im.data<T>();
@@ -52,16 +68,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
5268
int c_im = c / filter_width / filter_height;
5369
for (int h = 0; h < output_height; ++h) {
5470
for (int w = 0; w < output_width; ++w) {
55-
int im_row_idx = h * stride_height + h_offset;
56-
int im_col_idx = w * stride_width + w_offset;
57-
if ((im_row_idx - padding_height) < 0 ||
58-
(im_row_idx - padding_height) >= input_height ||
59-
(im_col_idx - padding_width) < 0 ||
60-
(im_col_idx - padding_width) >= input_width) {
71+
int im_row_idx = h * stride_height + h_offset - padding_up;
72+
int im_col_idx = w * stride_width + w_offset - padding_left;
73+
74+
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
75+
im_col_idx >= input_width) {
6176
col_data[(c * output_height + h) * output_width + w] = T(0);
6277
} else {
63-
im_row_idx += c_im * input_height - padding_height;
64-
im_col_idx -= padding_width;
78+
im_row_idx += c_im * input_height;
6579
col_data[(c * output_height + h) * output_width + w] =
6680
im_data[im_row_idx * input_width + im_col_idx];
6781
}
@@ -82,7 +96,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
8296
public:
8397
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
8498
const framework::Tensor& col, int stride_height,
85-
int stride_width, int padding_height, int padding_width) {
99+
int stride_width, int padding_up, int padding_down,
100+
int padding_left, int padding_right) {
86101
PADDLE_ENFORCE(im.dims().size() == 3);
87102
PADDLE_ENFORCE(col.dims().size() == 5);
88103
int input_channels = im.dims()[0];
@@ -92,6 +107,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
92107
int filter_width = col.dims()[2];
93108
int output_height = col.dims()[3];
94109
int output_width = col.dims()[4];
110+
111+
PADDLE_ENFORCE_EQ(
112+
(input_height + padding_up + padding_down - filter_height) /
113+
stride_height +
114+
1,
115+
output_height,
116+
"Output_height and padding(padding_up, padding_down) are "
117+
"inconsistent.");
118+
PADDLE_ENFORCE_EQ(
119+
(input_width + padding_left + padding_right - filter_width) /
120+
stride_width +
121+
1,
122+
output_width,
123+
"output_width and padding(padding_left, padding_right) are "
124+
"inconsistent.");
125+
95126
int channels_col = input_channels * filter_height * filter_width;
96127

97128
T* im_data = im.data<T>();
@@ -103,14 +134,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
103134
int c_im = c / filter_width / filter_height;
104135
for (int h = 0; h < output_height; ++h) {
105136
for (int w = 0; w < output_width; ++w) {
106-
int im_row_idx = h * stride_height + h_offset;
107-
int im_col_idx = w * stride_width + w_offset;
108-
if ((im_row_idx - padding_height) >= 0 &&
109-
(im_row_idx - padding_height) < input_height &&
110-
(im_col_idx - padding_width) >= 0 &&
111-
(im_col_idx - padding_width) < input_width) {
112-
im_row_idx += c_im * input_height - padding_height;
113-
im_col_idx -= padding_width;
137+
int im_row_idx = h * stride_height + h_offset - padding_up;
138+
int im_col_idx = w * stride_width + w_offset - padding_left;
139+
140+
if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
141+
(im_col_idx) >= 0 && (im_col_idx) < input_width) {
142+
im_row_idx += c_im * input_height;
114143
im_data[im_row_idx * input_width + im_col_idx] +=
115144
col_data[(c * output_height + h) * output_width + w];
116145
}
@@ -140,8 +169,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
140169
public:
141170
void operator()(const platform::DeviceContext& context,
142171
const framework::Tensor& im, framework::Tensor& col,
143-
int stride_height, int stride_width, int padding_height,
144-
int padding_width) {
172+
int stride_height, int stride_width, int padding_up,
173+
int padding_down, int padding_left, int padding_right) {
145174
PADDLE_ENFORCE(im.dims().size() == 3);
146175
PADDLE_ENFORCE(col.dims().size() == 5);
147176
int input_channels = im.dims()[0];
@@ -152,6 +181,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
152181
int output_height = col.dims()[0];
153182
int output_width = col.dims()[1];
154183

184+
PADDLE_ENFORCE_EQ(
185+
(input_height + padding_up + padding_down - filter_height) /
186+
stride_height +
187+
1,
188+
output_height,
189+
"Output_height and padding(padding_up, padding_down) are "
190+
"inconsistent.");
191+
PADDLE_ENFORCE_EQ(
192+
(input_width + padding_left + padding_right - filter_width) /
193+
stride_width +
194+
1,
195+
output_width,
196+
"output_width and padding(padding_left, padding_right) are "
197+
"inconsistent.");
198+
155199
const T* im_data = im.data<T>();
156200
T* col_data = col.data<T>();
157201

@@ -163,10 +207,10 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
163207
for (int filter_col_idx = 0; filter_col_idx < filter_width;
164208
++filter_col_idx) {
165209
int im_row_offset =
166-
col_row_idx * stride_height + filter_row_idx - padding_height;
210+
col_row_idx * stride_height + filter_row_idx - padding_up;
167211
int im_col_offset =
168-
col_col_idx * stride_width + filter_col_idx - padding_width;
169-
int col_offset = (((col_row_idx * output_width + col_col_idx) *
212+
col_col_idx * stride_width + filter_col_idx - padding_left;
213+
int col_offset = ((((col_row_idx)*output_width + col_col_idx) *
170214
input_channels +
171215
channel) *
172216
filter_height +
@@ -201,7 +245,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
201245
public:
202246
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
203247
const framework::Tensor& col, int stride_height,
204-
int stride_width, int padding_height, int padding_width) {
248+
int stride_width, int padding_up, int padding_down,
249+
int padding_left, int padding_right) {
205250
PADDLE_ENFORCE(im.dims().size() == 3);
206251
PADDLE_ENFORCE(col.dims().size() == 5);
207252
int input_channels = im.dims()[0];
@@ -212,6 +257,21 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
212257
int output_height = col.dims()[0];
213258
int output_width = col.dims()[1];
214259

260+
PADDLE_ENFORCE_EQ(
261+
(input_height + padding_up + padding_down - filter_height) /
262+
stride_height +
263+
1,
264+
output_height,
265+
"Output_height and padding(padding_up, padding_down) are "
266+
"inconsistent.");
267+
PADDLE_ENFORCE_EQ(
268+
(input_width + padding_left + padding_right - filter_width) /
269+
stride_width +
270+
1,
271+
output_width,
272+
"output_width and padding(padding_left, padding_right) are "
273+
"inconsistent.");
274+
215275
T* im_data = im.data<T>();
216276
const T* col_data = col.data<T>();
217277

@@ -223,9 +283,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
223283
for (int filter_col_idx = 0; filter_col_idx < filter_width;
224284
++filter_col_idx) {
225285
int im_row_offset =
226-
col_row_idx * stride_height + filter_row_idx - padding_height;
286+
col_row_idx * stride_height + filter_row_idx - padding_up;
227287
int im_col_offset =
228-
col_col_idx * stride_width + filter_col_idx - padding_width;
288+
col_col_idx * stride_width + filter_col_idx - padding_left;
229289
int col_offset = (((col_row_idx * output_width + col_col_idx) *
230290
input_channels +
231291
channel) *

paddle/operators/math/im2col.cu

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
6666
public:
6767
void operator()(const platform::DeviceContext& context,
6868
const framework::Tensor& im, framework::Tensor& col,
69-
int stride_height, int stride_width, int padding_height,
70-
int padding_width) {
69+
int stride_height, int stride_width, int padding_up,
70+
int padding_down, int padding_left, int padding_right) {
7171
PADDLE_ENFORCE(im.dims().size() == 3);
7272
PADDLE_ENFORCE(col.dims().size() == 5);
7373

@@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
7979
int output_height = col.dims()[3];
8080
int output_width = col.dims()[4];
8181

82+
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
83+
stride_height +
84+
1 ==
85+
output_height);
86+
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
87+
stride_width +
88+
1 ==
89+
output_width);
90+
8291
int num_outputs = input_channels * output_height * output_width;
8392
int blocks = (num_outputs + 1024 - 1) / 1024;
8493
int block_x = 512;
@@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
8998
reinterpret_cast<const platform::CUDADeviceContext&>(context)
9099
.stream()>>>(
91100
im.data<T>(), num_outputs, input_height, input_width, filter_height,
92-
filter_width, stride_height, stride_width, padding_height,
93-
padding_width, output_height, output_width, col.data<T>());
101+
filter_width, stride_height, stride_width, padding_up, padding_left,
102+
output_height, output_width, col.data<T>());
94103
}
95104
};
96105

@@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
152161
public:
153162
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
154163
const framework::Tensor& col, int stride_height,
155-
int stride_width, int padding_height, int padding_width) {
164+
int stride_width, int padding_up, int padding_down,
165+
int padding_left, int padding_right) {
156166
PADDLE_ENFORCE(im.dims().size() == 3);
157167
PADDLE_ENFORCE(col.dims().size() == 5);
158168

@@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
164174
int output_height = col.dims()[3];
165175
int output_width = col.dims()[4];
166176

167-
size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
168-
(input_width + 2 * padding_width);
177+
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
178+
stride_height +
179+
1 ==
180+
output_height);
181+
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
182+
stride_width +
183+
1 ==
184+
output_width);
185+
186+
size_t num_kernels = input_channels *
187+
(input_height + padding_up + padding_down) *
188+
(input_width + padding_left + padding_right);
169189

170190
size_t blocks = (num_kernels + 1024 - 1) / 1024;
171191
size_t block_x = 512;
@@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
178198
col2im<T><<<grid, threads, 0,
179199
reinterpret_cast<const platform::CUDADeviceContext&>(context)
180200
.stream()>>>(
181-
num_kernels, col.data<T>(), input_height + 2 * padding_height,
182-
input_width + 2 * padding_width, input_channels, filter_height,
183-
filter_width, stride_height, stride_width, padding_height,
184-
padding_width, output_height, output_width, im.data<T>());
201+
num_kernels, col.data<T>(), input_height + padding_up + padding_down,
202+
input_width + padding_left + padding_left, input_channels,
203+
filter_height, filter_width, stride_height, stride_width, padding_up,
204+
padding_left, output_height, output_width, im.data<T>());
185205
}
186206
};
187207

@@ -238,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
238258
public:
239259
void operator()(const platform::DeviceContext& context,
240260
const framework::Tensor& im, framework::Tensor& col,
241-
int stride_height, int stride_width, int padding_height,
242-
int padding_width) {
261+
int stride_height, int stride_width, int padding_up,
262+
int padding_down, int padding_left, int padding_right) {
243263
PADDLE_ENFORCE(im.dims().size() == 3);
244264
PADDLE_ENFORCE(col.dims().size() == 5);
245265
int input_channels = im.dims()[0];
@@ -250,6 +270,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
250270
int output_height = col.dims()[0];
251271
int output_width = col.dims()[1];
252272

273+
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
274+
stride_height +
275+
1 ==
276+
output_height);
277+
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
278+
stride_width +
279+
1 ==
280+
output_width);
281+
253282
int block_dim_x = 0;
254283
int block_dim_y = 0;
255284
if (filter_height <= 4 && filter_width <= 4) {
@@ -274,8 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
274303
reinterpret_cast<const platform::CUDADeviceContext&>(context)
275304
.stream()>>>(
276305
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
277-
filter_height, filter_width, stride_height, stride_width,
278-
padding_height, padding_width, output_height, output_width);
306+
filter_height, filter_width, stride_height, stride_width, padding_up,
307+
padding_left, output_height, output_width);
279308
}
280309
};
281310

@@ -322,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
322351
public:
323352
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
324353
const framework::Tensor& col, int stride_height,
325-
int stride_width, int padding_height, int padding_width) {
354+
int stride_width, int padding_up, int padding_down,
355+
int padding_left, int padding_right) {
326356
PADDLE_ENFORCE(im.dims().size() == 3);
327357
PADDLE_ENFORCE(col.dims().size() == 5);
328358
int input_channels = im.dims()[0];
@@ -333,6 +363,15 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
333363
int output_height = col.dims()[0];
334364
int output_width = col.dims()[1];
335365

366+
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
367+
stride_height +
368+
1 ==
369+
output_height);
370+
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
371+
stride_width +
372+
1 ==
373+
output_width);
374+
336375
int block_dim_x = 0;
337376
int block_dim_y = 0;
338377
if (filter_height <= 4 && filter_width <= 4) {
@@ -357,8 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
357396
reinterpret_cast<const platform::CUDADeviceContext&>(context)
358397
.stream()>>>(
359398
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
360-
filter_height, filter_width, stride_height, stride_width,
361-
padding_height, padding_width, output_height, output_width);
399+
filter_height, filter_width, stride_height, stride_width, padding_up,
400+
padding_left, output_height, output_width);
362401
}
363402
};
364403

0 commit comments

Comments
 (0)