@@ -14,6 +14,7 @@ limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/math/im2col.h"
16
16
#include < vector>
17
+ #include " paddle/fluid/operators/math/im2col_cfo_cpu.h"
17
18
18
19
namespace paddle {
19
20
namespace operators {
@@ -35,210 +36,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
35
36
PADDLE_ENFORCE (im.dims ().size () == 3 );
36
37
PADDLE_ENFORCE (col->dims ().size () == 5 );
37
38
38
- int im_channels = im.dims ()[0 ];
39
- int im_height = im.dims ()[1 ];
40
- int im_width = im.dims ()[2 ];
41
- int filter_height = col->dims ()[1 ];
42
- int filter_width = col->dims ()[2 ];
43
- int output_height = col->dims ()[3 ];
44
- int output_width = col->dims ()[4 ];
45
-
46
- int channels_col = im_channels * filter_height * filter_width;
47
-
48
- const T* im_data = im.data <T>();
49
- T* col_data = col->data <T>();
50
- // TODO(TJ): change me to template
51
- // further optimize: padding == 1 need special
52
39
if (stride[0 ] == 1 && stride[1 ] == 1 && dilation[0 ] == 1 &&
53
40
dilation[1 ] == 1 ) {
54
- int col_matrix_width = output_width * output_height;
55
- int im_size = im_height * im_width;
56
41
if (padding[0 ] == 0 && padding[1 ] == 0 ) {
57
- size_t copy_size = sizeof (T) * output_width;
58
- for (int oh = 0 ; oh < output_height; ++oh) {
59
- const T* im_data_start = im_data + oh * im_width;
60
- T* dst_data = col_data + oh * output_width;
61
- for (int ic = 0 ; ic < im_channels; ++ic) {
62
- const T* src_data = im_data_start + ic * im_size;
63
- for (int kh = 0 ; kh < filter_height; ++kh) {
64
- for (int kw = 0 ; kw < filter_width; ++kw) {
65
- std::memcpy (dst_data, src_data + kw, copy_size);
66
- dst_data = dst_data + col_matrix_width;
67
- }
68
- src_data = src_data + im_width;
69
- }
70
- }
71
- }
72
- return ;
42
+ im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
73
43
} else {
74
- int plh = padding[0 ];
75
- int plw = padding[1 ];
76
- int prh =
77
- (output_height - 1 ) * stride[0 ] + filter_height - im_height - plh;
78
- int prw =
79
- (output_width - 1 ) * stride[1 ] + filter_width - im_width - plw;
80
-
81
- // fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82
- // TODO(TJ): refine ph*xxx
83
- assert (plh == prh); // because stride_h == 1
84
- int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow
85
- int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow
86
- for (int ph = 0 ; ph < plh; ++ph) {
87
- int sz = output_width * (plh - ph);
88
- size_t copy_sz = sizeof (T) * sz;
89
- T* col_start_l = col_data + ph * col_block_fh;
90
- T* col_start_r = col_data + (filter_height - ph - 1 ) * col_block_fh +
91
- col_matrix_width - sz;
92
- for (int ic = 0 ; ic < im_channels; ++ic) {
93
- T* dst_data_l = col_start_l + ic * col_block_ic;
94
- T* dst_data_r = col_start_r + ic * col_block_ic;
95
- for (int kw = 0 ; kw < filter_width; ++kw) {
96
- std::memset (dst_data_l, 0 , copy_sz);
97
- std::memset (dst_data_r, 0 , copy_sz);
98
- dst_data_l = dst_data_l + col_matrix_width;
99
- dst_data_r = dst_data_r + col_matrix_width;
100
- }
101
- }
102
- }
103
-
104
- // fill width padding
105
- assert (plw == prw); // because stride_w == 1
106
- if (plw == 1 ) {
107
- auto pad = static_cast <T>(0 ); // padding zero
108
- for (int ic = 0 ; ic < im_channels; ++ic) {
109
- // TODO(TJ): use add and resue stride
110
- T* dst_data_ic = col_data + ic * col_block_ic;
111
- for (int kh = 0 ; kh < filter_height; ++kh) {
112
- T* dst_data_kh = dst_data_ic + kh * col_block_fh;
113
- for (T* dst_data :
114
- {dst_data_kh, dst_data_kh +
115
- (filter_width - prw) * col_matrix_width +
116
- output_width - 1 }) {
117
- // TODO(TJ): from plh, saving repeated assignment
118
- for (int oh = 0 ; oh < output_height; ++oh) {
119
- *dst_data = pad;
120
- dst_data = dst_data + output_width;
121
- }
122
- }
123
- }
124
- }
125
- } else {
126
- // padding_size > 1
127
- for (int ic = 0 ; ic < im_channels; ++ic) {
128
- // TODO(TJ): use add and resue stride
129
- T* dst_data_ic = col_data + ic * col_block_ic;
130
- for (int kh = 0 ; kh < filter_height; ++kh) {
131
- T* dst_data_kh = dst_data_ic + kh * col_block_fh;
132
- for (int kw = 0 ; kw < plw; ++kw) {
133
- // TODO(TJ): reuse array outside this for
134
- size_t sz = sizeof (T) * (plw - kw);
135
- T* dst_data = dst_data_kh + kw * col_matrix_width;
136
- // TODO(TJ): from plh, saving repeated assignment
137
- for (int oh = 0 ; oh < output_height; ++oh) {
138
- std::memset (dst_data, 0 , sz);
139
- dst_data = dst_data + output_width;
140
- }
141
- }
142
- // TODO(TJ): use reverse to save cache
143
- for (int kw = 0 ; kw < prw; ++kw) {
144
- // TODO(TJ): reuse array outside this for
145
- auto num = (prw - kw);
146
- size_t sz = sizeof (T) * num;
147
- T* dst_data = dst_data_kh +
148
- (filter_width - 1 - kw) * col_matrix_width +
149
- output_width - num;
150
- // TODO(TJ): from plh, saving repeated assignment
151
- for (int oh = 0 ; oh < output_height; ++oh) {
152
- std::memset (dst_data, 0 , sz);
153
- dst_data = dst_data + output_width;
154
- }
155
- }
156
- }
157
- }
158
- }
159
-
160
- // fill im_data
161
- // padding cover two cases:
162
- // 1. kw > 2*pw: kw = 3, pw = 1
163
- // 0 x x x x ... x x x x 0
164
- // 1 1 1 1 1 1
165
- // ==>
166
- // 0 x ... x x
167
- // x x ... x x
168
- // x x ... x 0
169
- // 2. kw < 2*pw: kw = 3, pw = 2
170
- // 0 0 x x x ... x x x 0 0
171
- // 1 1 1 1 1 1
172
- // ==>
173
- // 0 0 x ... x x x
174
- // 0 x x ... x x 0
175
- // x x x ... x 0 0
176
-
177
- // TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
178
- // (output_width-1)}
179
- // length of copy_size is equal kw.
180
- if (plw + prw < filter_width) {
181
- for (int oh = 0 ; oh < output_height; ++oh) {
182
- const T* im_data_start =
183
- im_data + (oh - plh > 0 ? oh - plh : 0 ) * im_width;
184
- T* dst_data = col_data + oh * output_width;
185
- for (int ic = 0 ; ic < im_channels; ++ic) {
186
- const T* src_data = im_data_start + ic * im_size;
187
- for (int kh = 0 ; kh < filter_height; ++kh) {
188
- if ((oh < plh && kh < plh) ||
189
- (oh > (output_height - prh - 1 ) &&
190
- kh > (filter_height - prh - 1 ))) {
191
- dst_data = dst_data + filter_width * col_matrix_width;
192
- continue ;
193
- }
194
- // TODO(TJ): reuse plw-kw outside this for
195
- // try to unify
196
- for (int kw = 0 ; kw < plw; ++kw) {
197
- std::memcpy (dst_data + (plw - kw), src_data,
198
- sizeof (T) * (output_width - (plw - kw)));
199
- dst_data = dst_data + col_matrix_width;
200
- }
201
- for (int kw = plw; kw < filter_width - prw; ++kw) {
202
- std::memcpy (dst_data, src_data + (kw - plw),
203
- sizeof (T) * output_width);
204
- dst_data = dst_data + col_matrix_width;
205
- }
206
- int i = 1 ;
207
- for (int kw = filter_width - prw; kw < filter_width;
208
- ++kw, ++i) {
209
- std::memcpy (dst_data, src_data + (kw - plw),
210
- sizeof (T) * (output_width - i));
211
- dst_data = dst_data + col_matrix_width;
212
- }
213
- src_data = src_data + im_width;
214
- }
215
- }
216
- }
217
- } else {
218
- LOG (FATAL) << " Not implement yet" ;
219
- }
220
- return ;
221
- }
222
- }
223
-
224
- for (int c = 0 ; c < channels_col; ++c) {
225
- int w_offset = c % filter_width;
226
- int h_offset = (c / filter_width) % filter_height;
227
- int c_im = c / (filter_width * filter_height);
228
- for (int h = 0 ; h < output_height; ++h) {
229
- int im_row_idx = h * stride[0 ] - padding[0 ] + h_offset * dilation[0 ];
230
- for (int w = 0 ; w < output_width; ++w) {
231
- int im_col_idx = w * stride[1 ] - padding[1 ] + w_offset * dilation[1 ];
232
- int col_idx = (c * output_height + h) * output_width + w;
233
- int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
234
-
235
- col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
236
- im_col_idx < 0 || im_col_idx >= im_width)
237
- ? static_cast <T>(0 )
238
- : im_data[im_idx];
239
- }
44
+ im2col_sh1sw1dh1dw1<T>(im, padding, col);
240
45
}
46
+ return ;
241
47
}
48
+ im2col_common<T>(im, dilation, stride, padding, col);
242
49
}
243
50
};
244
51
0 commit comments