Skip to content

Commit 507c143

Browse files
committed
im2col cfo cpu code clean
1 parent 4eeed0b commit 507c143

File tree

2 files changed

+270
-198
lines changed

2 files changed

+270
-198
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 5 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/im2col.h"
1616
#include <vector>
17+
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
1718

1819
namespace paddle {
1920
namespace operators {
@@ -35,210 +36,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
3536
PADDLE_ENFORCE(im.dims().size() == 3);
3637
PADDLE_ENFORCE(col->dims().size() == 5);
3738

38-
int im_channels = im.dims()[0];
39-
int im_height = im.dims()[1];
40-
int im_width = im.dims()[2];
41-
int filter_height = col->dims()[1];
42-
int filter_width = col->dims()[2];
43-
int output_height = col->dims()[3];
44-
int output_width = col->dims()[4];
45-
46-
int channels_col = im_channels * filter_height * filter_width;
47-
48-
const T* im_data = im.data<T>();
49-
T* col_data = col->data<T>();
50-
// TODO(TJ): change me to template
51-
// further optimize: padding == 1 need special
5239
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
5340
dilation[1] == 1) {
54-
int col_matrix_width = output_width * output_height;
55-
int im_size = im_height * im_width;
5641
if (padding[0] == 0 && padding[1] == 0) {
57-
size_t copy_size = sizeof(T) * output_width;
58-
for (int oh = 0; oh < output_height; ++oh) {
59-
const T* im_data_start = im_data + oh * im_width;
60-
T* dst_data = col_data + oh * output_width;
61-
for (int ic = 0; ic < im_channels; ++ic) {
62-
const T* src_data = im_data_start + ic * im_size;
63-
for (int kh = 0; kh < filter_height; ++kh) {
64-
for (int kw = 0; kw < filter_width; ++kw) {
65-
std::memcpy(dst_data, src_data + kw, copy_size);
66-
dst_data = dst_data + col_matrix_width;
67-
}
68-
src_data = src_data + im_width;
69-
}
70-
}
71-
}
72-
return;
42+
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
7343
} else {
74-
int plh = padding[0];
75-
int plw = padding[1];
76-
int prh =
77-
(output_height - 1) * stride[0] + filter_height - im_height - plh;
78-
int prw =
79-
(output_width - 1) * stride[1] + filter_width - im_width - plw;
80-
81-
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82-
// TODO(TJ): refine ph*xxx
83-
assert(plh == prh); // because stride_h == 1
84-
int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow
85-
int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow
86-
for (int ph = 0; ph < plh; ++ph) {
87-
int sz = output_width * (plh - ph);
88-
size_t copy_sz = sizeof(T) * sz;
89-
T* col_start_l = col_data + ph * col_block_fh;
90-
T* col_start_r = col_data + (filter_height - ph - 1) * col_block_fh +
91-
col_matrix_width - sz;
92-
for (int ic = 0; ic < im_channels; ++ic) {
93-
T* dst_data_l = col_start_l + ic * col_block_ic;
94-
T* dst_data_r = col_start_r + ic * col_block_ic;
95-
for (int kw = 0; kw < filter_width; ++kw) {
96-
std::memset(dst_data_l, 0, copy_sz);
97-
std::memset(dst_data_r, 0, copy_sz);
98-
dst_data_l = dst_data_l + col_matrix_width;
99-
dst_data_r = dst_data_r + col_matrix_width;
100-
}
101-
}
102-
}
103-
104-
// fill width padding
105-
assert(plw == prw); // because stride_w == 1
106-
if (plw == 1) {
107-
auto pad = static_cast<T>(0); // padding zero
108-
for (int ic = 0; ic < im_channels; ++ic) {
109-
// TODO(TJ): use add and resue stride
110-
T* dst_data_ic = col_data + ic * col_block_ic;
111-
for (int kh = 0; kh < filter_height; ++kh) {
112-
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
113-
for (T* dst_data :
114-
{dst_data_kh, dst_data_kh +
115-
(filter_width - prw) * col_matrix_width +
116-
output_width - 1}) {
117-
// TODO(TJ): from plh, saving repeated assignment
118-
for (int oh = 0; oh < output_height; ++oh) {
119-
*dst_data = pad;
120-
dst_data = dst_data + output_width;
121-
}
122-
}
123-
}
124-
}
125-
} else {
126-
// padding_size > 1
127-
for (int ic = 0; ic < im_channels; ++ic) {
128-
// TODO(TJ): use add and resue stride
129-
T* dst_data_ic = col_data + ic * col_block_ic;
130-
for (int kh = 0; kh < filter_height; ++kh) {
131-
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
132-
for (int kw = 0; kw < plw; ++kw) {
133-
// TODO(TJ): reuse array outside this for
134-
size_t sz = sizeof(T) * (plw - kw);
135-
T* dst_data = dst_data_kh + kw * col_matrix_width;
136-
// TODO(TJ): from plh, saving repeated assignment
137-
for (int oh = 0; oh < output_height; ++oh) {
138-
std::memset(dst_data, 0, sz);
139-
dst_data = dst_data + output_width;
140-
}
141-
}
142-
// TODO(TJ): use reverse to save cache
143-
for (int kw = 0; kw < prw; ++kw) {
144-
// TODO(TJ): reuse array outside this for
145-
auto num = (prw - kw);
146-
size_t sz = sizeof(T) * num;
147-
T* dst_data = dst_data_kh +
148-
(filter_width - 1 - kw) * col_matrix_width +
149-
output_width - num;
150-
// TODO(TJ): from plh, saving repeated assignment
151-
for (int oh = 0; oh < output_height; ++oh) {
152-
std::memset(dst_data, 0, sz);
153-
dst_data = dst_data + output_width;
154-
}
155-
}
156-
}
157-
}
158-
}
159-
160-
// fill im_data
161-
// padding cover two cases:
162-
// 1. kw > 2*pw: kw = 3, pw = 1
163-
// 0 x x x x ... x x x x 0
164-
// 1 1 1 1 1 1
165-
// ==>
166-
// 0 x ... x x
167-
// x x ... x x
168-
// x x ... x 0
169-
// 2. kw < 2*pw: kw = 3, pw = 2
170-
// 0 0 x x x ... x x x 0 0
171-
// 1 1 1 1 1 1
172-
// ==>
173-
// 0 0 x ... x x x
174-
// 0 x x ... x x 0
175-
// x x x ... x 0 0
176-
177-
// TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
178-
// (output_width-1)}
179-
// length of copy_size is equal kw.
180-
if (plw + prw < filter_width) {
181-
for (int oh = 0; oh < output_height; ++oh) {
182-
const T* im_data_start =
183-
im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
184-
T* dst_data = col_data + oh * output_width;
185-
for (int ic = 0; ic < im_channels; ++ic) {
186-
const T* src_data = im_data_start + ic * im_size;
187-
for (int kh = 0; kh < filter_height; ++kh) {
188-
if ((oh < plh && kh < plh) ||
189-
(oh > (output_height - prh - 1) &&
190-
kh > (filter_height - prh - 1))) {
191-
dst_data = dst_data + filter_width * col_matrix_width;
192-
continue;
193-
}
194-
// TODO(TJ): reuse plw-kw outside this for
195-
// try to unify
196-
for (int kw = 0; kw < plw; ++kw) {
197-
std::memcpy(dst_data + (plw - kw), src_data,
198-
sizeof(T) * (output_width - (plw - kw)));
199-
dst_data = dst_data + col_matrix_width;
200-
}
201-
for (int kw = plw; kw < filter_width - prw; ++kw) {
202-
std::memcpy(dst_data, src_data + (kw - plw),
203-
sizeof(T) * output_width);
204-
dst_data = dst_data + col_matrix_width;
205-
}
206-
int i = 1;
207-
for (int kw = filter_width - prw; kw < filter_width;
208-
++kw, ++i) {
209-
std::memcpy(dst_data, src_data + (kw - plw),
210-
sizeof(T) * (output_width - i));
211-
dst_data = dst_data + col_matrix_width;
212-
}
213-
src_data = src_data + im_width;
214-
}
215-
}
216-
}
217-
} else {
218-
LOG(FATAL) << "Not implement yet";
219-
}
220-
return;
221-
}
222-
}
223-
224-
for (int c = 0; c < channels_col; ++c) {
225-
int w_offset = c % filter_width;
226-
int h_offset = (c / filter_width) % filter_height;
227-
int c_im = c / (filter_width * filter_height);
228-
for (int h = 0; h < output_height; ++h) {
229-
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
230-
for (int w = 0; w < output_width; ++w) {
231-
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
232-
int col_idx = (c * output_height + h) * output_width + w;
233-
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
234-
235-
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
236-
im_col_idx < 0 || im_col_idx >= im_width)
237-
? static_cast<T>(0)
238-
: im_data[im_idx];
239-
}
44+
im2col_sh1sw1dh1dw1<T>(im, padding, col);
24045
}
46+
return;
24147
}
48+
im2col_common<T>(im, dilation, stride, padding, col);
24249
}
24350
};
24451

0 commit comments

Comments
 (0)