Skip to content

Commit aae994f

Browse files
committed
refine im2col no padding
1 parent 03d70c1 commit aae994f

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,46 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
4040
int im_width = im.dims()[2];
4141
int filter_height = col->dims()[1];
4242
int filter_width = col->dims()[2];
43-
int col_height = col->dims()[3];
44-
int col_width = col->dims()[4];
43+
int output_height = col->dims()[3];
44+
int output_width = col->dims()[4];
4545

4646
int channels_col = im_channels * filter_height * filter_width;
4747

4848
const T* im_data = im.data<T>();
4949
T* col_data = col->data<T>();
50+
// TODO(TJ): change me to template
51+
// further optimize:
52+
// 1. padding != 1
53+
// 2. could also support stride_h != 1
54+
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
55+
dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) {
56+
int col_matrix_width = output_width * output_height;
57+
for (int oh = 0; oh < output_height; ++oh) {
58+
const T* im_data_start = im_data + oh * im_width;
59+
T* dst_data = col_data + oh * output_width;
60+
for (int ic = 0; ic < im_channels; ++ic) {
61+
const T* src_data = im_data_start + ic * im_height * im_width;
62+
for (int kh = 0; kh < filter_height; ++kh) {
63+
for (int kw = 0; kw < filter_width; ++kw) {
64+
std::memcpy(dst_data, src_data + kw, sizeof(T) * output_width);
65+
dst_data = dst_data + col_matrix_width;
66+
}
67+
src_data = src_data + im_width;
68+
}
69+
}
70+
}
71+
return;
72+
}
73+
5074
for (int c = 0; c < channels_col; ++c) {
5175
int w_offset = c % filter_width;
5276
int h_offset = (c / filter_width) % filter_height;
5377
int c_im = c / (filter_width * filter_height);
54-
for (int h = 0; h < col_height; ++h) {
78+
for (int h = 0; h < output_height; ++h) {
5579
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
56-
for (int w = 0; w < col_width; ++w) {
80+
for (int w = 0; w < output_width; ++w) {
5781
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
58-
int col_idx = (c * col_height + h) * col_width + w;
82+
int col_idx = (c * output_height + h) * output_width + w;
5983
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
6084

6185
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||

0 commit comments

Comments
 (0)