Skip to content

Commit e3131e2

Browse files
committed
enable width padding
1 parent 92518c5 commit e3131e2

File tree

1 file changed

+61
-3
lines changed

1 file changed

+61
-3
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
7272
return;
7373
} else {
7474
int plh = padding[0];
75-
// int plw = padding[1];
75+
int plw = padding[1];
7676
int prh =
7777
(output_height - 1) * stride[0] + filter_height - im_height - plh;
78-
// int prw = (output_width - 1) * stride[1] + filter_width - im_width -
79-
// plw;
78+
int prw =
79+
(output_width - 1) * stride[1] + filter_width - im_width - plw;
8080

8181
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
8282
// TODO(TJ): refine ph*xxx
@@ -100,6 +100,64 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
100100
}
101101
}
102102
}
103+
104+
// fill width padding
105+
assert(plw == prw); // because stride_w == 1
106+
if (plw == 1) {
107+
auto pad = static_cast<T>(0); // padding zero
108+
for (int ic = 0; ic < im_channels; ++ic) {
109+
// TODO(TJ): use add and resue stride
110+
T* dst_data_ic = col_data + ic * col_block_ic;
111+
for (int kh = 0; kh < filter_height; ++kh) {
112+
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
113+
for (T* dst_data :
114+
{dst_data_kh, dst_data_kh +
115+
(filter_width - prw) * col_matrix_width +
116+
output_width - 1}) {
117+
// TODO(TJ): from plh, saving repeated assignment
118+
for (int oh = 0; oh < output_height; ++oh) {
119+
*dst_data = pad;
120+
dst_data = dst_data + output_width;
121+
}
122+
}
123+
}
124+
}
125+
} else {
126+
// padding_size > 1
127+
for (int ic = 0; ic < im_channels; ++ic) {
128+
// TODO(TJ): use add and resue stride
129+
T* dst_data_ic =
130+
col_data + ic * filter_width * filter_height * col_matrix_width;
131+
for (int kh = 0; kh < filter_height; ++kh) {
132+
T* dst_data_kh =
133+
dst_data_ic + kh * filter_width * col_matrix_width;
134+
for (int kw = 0; kw < plw; ++kw) {
135+
// TODO(TJ): reuse array outside this for
136+
size_t sz = sizeof(T) * (plw - kw);
137+
T* dst_data = dst_data_kh + kw * col_matrix_width;
138+
// TODO(TJ): from plh, saving repeated assignment
139+
for (int oh = 0; oh < output_height; ++oh) {
140+
std::memset(dst_data, 0, sz);
141+
dst_data = dst_data + output_width;
142+
}
143+
}
144+
// TODO(TJ): use reverse to save cache
145+
for (int kw = 0; kw < prw; ++kw) {
146+
// TODO(TJ): reuse array outside this for
147+
auto num = (prw - kw);
148+
size_t sz = sizeof(T) * num;
149+
T* dst_data = dst_data_kh +
150+
(filter_width - 1 - kw) * col_matrix_width +
151+
output_width - num;
152+
// TODO(TJ): from plh, saving repeated assignment
153+
for (int oh = 0; oh < output_height; ++oh) {
154+
std::memset(dst_data, 0, sz);
155+
dst_data = dst_data + output_width;
156+
}
157+
}
158+
}
159+
}
160+
}
103161
return;
104162
}
105163
}

0 commit comments

Comments
 (0)