Skip to content

Commit 660df12

Browse files
committed
enable padding!=0 and fill height padding with 0
1 parent d8e00fa commit 660df12

File tree

1 file changed

+49
-15
lines changed

1 file changed

+49
-15
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,63 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
4848
const T* im_data = im.data<T>();
4949
T* col_data = col->data<T>();
5050
// TODO(TJ): change me to template
51-
// further optimaze:
52-
// 1. padding != 1
53-
// 2. could also support stride_h != 1
51+
// further optimize: padding == 1 needs special handling
5452
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
55-
dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) {
53+
dilation[1] == 1) {
5654
int col_matrix_width = output_width * output_height;
5755
int im_size = im_height * im_width;
58-
size_t copy_size = sizeof(T) * output_width;
59-
for (int oh = 0; oh < output_height; ++oh) {
60-
const T* im_data_start = im_data + oh * im_width;
61-
T* dst_data = col_data + oh * output_width;
62-
for (int ic = 0; ic < im_channels; ++ic) {
63-
const T* src_data = im_data_start + ic * im_size;
64-
for (int kh = 0; kh < filter_height; ++kh) {
56+
if (padding[0] == 0 && padding[1] == 0) {
57+
size_t copy_size = sizeof(T) * output_width;
58+
for (int oh = 0; oh < output_height; ++oh) {
59+
const T* im_data_start = im_data + oh * im_width;
60+
T* dst_data = col_data + oh * output_width;
61+
for (int ic = 0; ic < im_channels; ++ic) {
62+
const T* src_data = im_data_start + ic * im_size;
63+
for (int kh = 0; kh < filter_height; ++kh) {
64+
for (int kw = 0; kw < filter_width; ++kw) {
65+
std::memcpy(dst_data, src_data + kw, copy_size);
66+
dst_data = dst_data + col_matrix_width;
67+
}
68+
src_data = src_data + im_width;
69+
}
70+
}
71+
}
72+
return;
73+
} else {
74+
int plh = padding[0];
75+
// int plw = padding[1];
76+
int prh =
77+
(output_height - 1) * stride[0] + filter_height - im_height - plh;
78+
// int prw = (output_width - 1) * stride[1] + filter_width - im_width -
79+
// plw;
80+
81+
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82+
// TODO(TJ): reuse sizes
83+
assert(plh == prh); // because stride_h == 1
84+
for (int ph = 0; ph < plh; ++ph) {
85+
size_t sz = sizeof(T) * output_width * (plh - ph);
86+
T* col_start_l = col_data + ph * filter_width * col_matrix_width;
87+
T* col_start_r =
88+
col_data +
89+
(filter_width - ph - 1) * filter_width * col_matrix_width +
90+
col_matrix_width - output_width * (plh - ph);
91+
for (int ic = 0; ic < im_channels; ++ic) {
92+
T* dst_data_l =
93+
col_start_l +
94+
ic * filter_width * filter_height * col_matrix_width;
95+
T* dst_data_r =
96+
col_start_r +
97+
ic * filter_width * filter_height * col_matrix_width;
6598
for (int kw = 0; kw < filter_width; ++kw) {
66-
std::memcpy(dst_data, src_data + kw, copy_size);
67-
dst_data = dst_data + col_matrix_width;
99+
std::memset(dst_data_l, 0, sz);
100+
std::memset(dst_data_r, 0, sz);
101+
dst_data_l = dst_data_l + col_matrix_width;
102+
dst_data_r = dst_data_r + col_matrix_width;
68103
}
69-
src_data = src_data + im_width;
70104
}
71105
}
106+
return;
72107
}
73-
return;
74108
}
75109

76110
for (int c = 0; c < channels_col; ++c) {

0 commit comments

Comments
 (0)