@@ -126,11 +126,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
126
126
// padding_size > 1
127
127
for (int ic = 0 ; ic < im_channels; ++ic) {
128
128
// TODO(TJ): use add and reuse stride
129
- T* dst_data_ic =
130
- col_data + ic * filter_width * filter_height * col_matrix_width;
129
+ T* dst_data_ic = col_data + ic * col_block_ic;
131
130
for (int kh = 0 ; kh < filter_height; ++kh) {
132
- T* dst_data_kh =
133
- dst_data_ic + kh * filter_width * col_matrix_width;
131
+ T* dst_data_kh = dst_data_ic + kh * col_block_fh;
134
132
for (int kw = 0 ; kw < plw; ++kw) {
135
133
// TODO(TJ): reuse array outside this for
136
134
size_t sz = sizeof (T) * (plw - kw);
@@ -158,6 +156,67 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
158
156
}
159
157
}
160
158
}
159
+
160
+ // fill im_data
161
+ // padding cover two cases:
162
+ // 1. kw > 2*pw: kw = 3, pw = 1
163
+ // 0 x x x x ... x x x x 0
164
+ // 1 1 1 1 1 1
165
+ // ==>
166
+ // 0 x ... x x
167
+ // x x ... x x
168
+ // x x ... x 0
169
+ // 2. kw < 2*pw: kw = 3, pw = 2
170
+ // 0 0 x x x ... x x x 0 0
171
+ // 1 1 1 1 1 1
172
+ // ==>
173
+ // 0 0 x ... x x x
174
+ // 0 x x ... x x 0
175
+ // x x x ... x 0 0
176
+
177
+ // TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
178
+ // (output_width-1)}
179
+ // length of copy_size is equal to kw.
180
+ if (plw + prw < filter_width) {
181
+ for (int oh = 0 ; oh < output_height; ++oh) {
182
+ const T* im_data_start =
183
+ im_data + (oh - plh > 0 ? oh - plh : 0 ) * im_width;
184
+ T* dst_data = col_data + oh * output_width;
185
+ for (int ic = 0 ; ic < im_channels; ++ic) {
186
+ const T* src_data = im_data_start + ic * im_size;
187
+ for (int kh = 0 ; kh < filter_height; ++kh) {
188
+ if ((oh < plh && kh < plh) ||
189
+ (oh > (output_height - prh - 1 ) &&
190
+ kh > (filter_height - prh - 1 ))) {
191
+ dst_data = dst_data + filter_width * col_matrix_width;
192
+ continue ;
193
+ }
194
+ // TODO(TJ): reuse plw-kw outside this for
195
+ // try to unify
196
+ for (int kw = 0 ; kw < plw; ++kw) {
197
+ std::memcpy (dst_data + (plw - kw), src_data,
198
+ sizeof (T) * (output_width - (plw - kw)));
199
+ dst_data = dst_data + col_matrix_width;
200
+ }
201
+ for (int kw = plw; kw < filter_width - prw; ++kw) {
202
+ std::memcpy (dst_data, src_data + (kw - plw),
203
+ sizeof (T) * output_width);
204
+ dst_data = dst_data + col_matrix_width;
205
+ }
206
+ int i = 1 ;
207
+ for (int kw = filter_width - prw; kw < filter_width;
208
+ ++kw, ++i) {
209
+ std::memcpy (dst_data, src_data + (kw - plw),
210
+ sizeof (T) * (output_width - i));
211
+ dst_data = dst_data + col_matrix_width;
212
+ }
213
+ src_data = src_data + im_width;
214
+ }
215
+ }
216
+ }
217
+ } else {
218
+ LOG (FATAL) << " Not implement yet" ;
219
+ }
161
220
return ;
162
221
}
163
222
}
0 commit comments