Skip to content

Commit 65d418f

Browse files
committed
complete im2col with padding==1 and speedup filter width==1
1 parent 52eb86e commit 65d418f

File tree

3 files changed

+113
-125
lines changed

3 files changed

+113
-125
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
4040
dilation[1] == 1) {
4141
if (padding[0] == 0 && padding[1] == 0) {
4242
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
43-
} else {
44-
im2col_sh1sw1dh1dw1<T>(im, padding, col);
43+
return;
44+
} else if (padding[0] == 1 && padding[1] == 1) {
45+
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
46+
return;
4547
}
46-
return;
48+
// TODO(TJ): complete padding >=2
4749
}
4850
im2col_common<T>(im, dilation, stride, padding, col);
4951
}

paddle/fluid/operators/math/im2col_cfo_cpu.h

Lines changed: 99 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace paddle {
2121
namespace operators {
2222
namespace math {
2323

24-
/*
24+
/**
2525
* The most common im2col algorithm.
2626
* Support dilation, stride and padding.
2727
*/
@@ -61,9 +61,9 @@ inline void im2col_common(const framework::Tensor& im,
6161
}
6262
}
6363

64-
/*
64+
/**
6565
* im2col algorithm with strides == 1, dilations == 1, paddings == 0
66-
* */
66+
*/
6767
template <typename T>
6868
inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
6969
framework::Tensor* col) {
@@ -96,131 +96,71 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
9696
}
9797
}
9898

99-
// further optimize: padding == 1 need special
99+
/**
100+
* im2col algorithm with strides == 1, dilations == 1, paddings == 1
101+
* and filter_width == 1 have a special implementation
102+
*/
100103
template <typename T>
101-
inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im,
102-
const std::vector<int>& padding,
103-
framework::Tensor* col) {
104+
inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
105+
framework::Tensor* col) {
104106
int im_channels = im.dims()[0];
105107
int im_height = im.dims()[1];
106108
int im_width = im.dims()[2];
107109
int filter_height = col->dims()[1];
108110
int filter_width = col->dims()[2];
109111
int output_height = col->dims()[3];
110112
int output_width = col->dims()[4];
111-
constexpr int sh = 1;
112-
constexpr int sw = 1;
113+
114+
constexpr int plh = 1;
115+
constexpr int prh = 1;
116+
constexpr int plw = 1;
117+
constexpr int prw = 1;
113118

114119
const T* im_data = im.data<T>();
115120
T* col_data = col->data<T>();
116-
int col_matrix_width = output_width * output_height;
117121
int im_size = im_height * im_width;
118-
119-
int plh = padding[0];
120-
int plw = padding[1];
121-
int prh = (output_height - 1) * sh + filter_height - im_height - plh;
122-
int prw = (output_width - 1) * sw + filter_width - im_width - plw;
123-
124-
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
125-
// TODO(TJ): refine ph*xxx
126-
assert(plh == prh); // because stride_h == 1
122+
int col_matrix_width = output_width * output_height;
127123
int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow
128124
int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow
129-
for (int ph = 0; ph < plh; ++ph) {
130-
int sz = output_width * (plh - ph);
131-
size_t copy_sz = sizeof(T) * sz;
132-
T* col_start_l = col_data + ph * col_block_fh;
133-
T* col_start_r = col_data + (filter_height - ph - 1) * col_block_fh +
134-
col_matrix_width - sz;
125+
126+
// fill height padding
127+
{
128+
size_t copy_size = sizeof(T) * output_width;
129+
T* col_start_l = col_data;
130+
T* col_start_r = col_data + (filter_height - 1) * col_block_fh +
131+
col_matrix_width - output_width;
135132
for (int ic = 0; ic < im_channels; ++ic) {
133+
// TODO(TJ): move * outside
136134
T* dst_data_l = col_start_l + ic * col_block_ic;
137135
T* dst_data_r = col_start_r + ic * col_block_ic;
138136
for (int kw = 0; kw < filter_width; ++kw) {
139-
std::memset(dst_data_l, 0, copy_sz);
140-
std::memset(dst_data_r, 0, copy_sz);
137+
std::memset(dst_data_l, 0, copy_size);
138+
std::memset(dst_data_r, 0, copy_size);
141139
dst_data_l = dst_data_l + col_matrix_width;
142140
dst_data_r = dst_data_r + col_matrix_width;
143141
}
144142
}
145143
}
146144

147-
// fill width padding
148-
assert(plw == prw); // because stride_w == 1
149-
if (plw == 1) {
150-
auto pad = static_cast<T>(0); // padding zero
145+
auto pad = static_cast<T>(0);
146+
if (filter_width == 1) {
147+
// fill width padding
151148
for (int ic = 0; ic < im_channels; ++ic) {
152-
// TODO(TJ): use add and resue stride
149+
// TODO(TJ): move * outside
153150
T* dst_data_ic = col_data + ic * col_block_ic;
154151
for (int kh = 0; kh < filter_height; ++kh) {
155-
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
156-
for (T* dst_data :
157-
{dst_data_kh, dst_data_kh +
158-
(filter_width - prw) * col_matrix_width +
159-
output_width - 1}) {
160-
// TODO(TJ): from plh, saving repeated assignment
161-
for (int oh = 0; oh < output_height; ++oh) {
162-
*dst_data = pad;
163-
dst_data = dst_data + output_width;
164-
}
152+
// TODO(TJ): move * outside
153+
T* dst_data = dst_data_ic + kh * col_block_fh;
154+
for (int oh = 0; oh < output_height; ++oh) {
155+
*dst_data = pad;
156+
dst_data = dst_data + output_width - 1;
157+
*dst_data = pad;
158+
++dst_data;
165159
}
166160
}
167161
}
168-
} else {
169-
// padding_size > 1
170-
for (int ic = 0; ic < im_channels; ++ic) {
171-
// TODO(TJ): use add and resue stride
172-
T* dst_data_ic = col_data + ic * col_block_ic;
173-
for (int kh = 0; kh < filter_height; ++kh) {
174-
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
175-
for (int kw = 0; kw < plw; ++kw) {
176-
// TODO(TJ): reuse array outside this for
177-
size_t sz = sizeof(T) * (plw - kw);
178-
T* dst_data = dst_data_kh + kw * col_matrix_width;
179-
// TODO(TJ): from plh, saving repeated assignment
180-
for (int oh = 0; oh < output_height; ++oh) {
181-
std::memset(dst_data, 0, sz);
182-
dst_data = dst_data + output_width;
183-
}
184-
}
185-
// TODO(TJ): use reverse to save cache
186-
for (int kw = 0; kw < prw; ++kw) {
187-
// TODO(TJ): reuse array outside this for
188-
auto num = (prw - kw);
189-
size_t sz = sizeof(T) * num;
190-
T* dst_data = dst_data_kh +
191-
(filter_width - 1 - kw) * col_matrix_width +
192-
output_width - num;
193-
// TODO(TJ): from plh, saving repeated assignment
194-
for (int oh = 0; oh < output_height; ++oh) {
195-
std::memset(dst_data, 0, sz);
196-
dst_data = dst_data + output_width;
197-
}
198-
}
199-
}
200-
}
201-
}
202-
203-
// fill im_data
204-
// padding cover two cases:
205-
// 1. kw > 2*pw: kw = 3, pw = 1
206-
// 0 x x x x ... x x x x 0
207-
// 1 1 1 1 1 1
208-
// ==>
209-
// 0 x ... x x
210-
// x x ... x x
211-
// x x ... x 0
212-
// 2. kw < 2*pw: kw = 3, pw = 2
213-
// 0 0 x x x ... x x x 0 0
214-
// 1 1 1 1 1 1
215-
// ==>
216-
// 0 0 x ... x x x
217-
// 0 x x ... x x 0
218-
// x x x ... x 0 0
219-
220-
// TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
221-
// (output_width-1)}
222-
// length of copy_size is equal kw.
223-
if (plw + prw < filter_width) {
162+
// fill core
163+
size_t copy_size = sizeof(T) * (output_width - plw - prw);
224164
for (int oh = 0; oh < output_height; ++oh) {
225165
const T* im_data_start =
226166
im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
@@ -230,33 +170,73 @@ inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im,
230170
for (int kh = 0; kh < filter_height; ++kh) {
231171
if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) &&
232172
kh > (filter_height - prh - 1))) {
233-
dst_data = dst_data + filter_width * col_matrix_width;
234-
continue;
235-
}
236-
// TODO(TJ): reuse plw-kw outside this for
237-
// try to unify
238-
for (int kw = 0; kw < plw; ++kw) {
239-
std::memcpy(dst_data + (plw - kw), src_data,
240-
sizeof(T) * (output_width - (plw - kw)));
241-
dst_data = dst_data + col_matrix_width;
242-
}
243-
for (int kw = plw; kw < filter_width - prw; ++kw) {
244-
std::memcpy(dst_data, src_data + (kw - plw),
245-
sizeof(T) * output_width);
246-
dst_data = dst_data + col_matrix_width;
247-
}
248-
int i = 1;
249-
for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
250-
std::memcpy(dst_data, src_data + (kw - plw),
251-
sizeof(T) * (output_width - i));
252173
dst_data = dst_data + col_matrix_width;
174+
continue;
253175
}
176+
std::memcpy(dst_data + plw, src_data, copy_size);
177+
dst_data = dst_data + col_matrix_width;
254178
src_data = src_data + im_width;
255179
}
256180
}
257181
}
258-
} else {
259-
LOG(FATAL) << "Not implement yet";
182+
return;
183+
}
184+
185+
// filter_width != 1
186+
// fill width padding
187+
for (int ic = 0; ic < im_channels; ++ic) {
188+
// TODO(TJ): move * outside
189+
T* dst_data_ic = col_data + ic * col_block_ic;
190+
for (int kh = 0; kh < filter_height; ++kh) {
191+
// TODO(TJ): move * outside
192+
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
193+
for (T* dst_data :
194+
{dst_data_kh, dst_data_kh + (filter_width - prw) * col_matrix_width +
195+
output_width - 1}) {
196+
// TODO(TJ): from plh, saving repeated assignment
197+
for (int oh = 0; oh < output_height; ++oh) {
198+
*dst_data = pad;
199+
dst_data = dst_data + output_width;
200+
}
201+
}
202+
}
203+
}
204+
205+
// TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
206+
// (output_width-1)}
207+
// length of copy_size is equal kw.
208+
for (int oh = 0; oh < output_height; ++oh) {
209+
const T* im_data_start = im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
210+
T* dst_data = col_data + oh * output_width;
211+
for (int ic = 0; ic < im_channels; ++ic) {
212+
const T* src_data = im_data_start + ic * im_size;
213+
for (int kh = 0; kh < filter_height; ++kh) {
214+
if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) &&
215+
kh > (filter_height - prh - 1))) {
216+
dst_data = dst_data + filter_width * col_matrix_width;
217+
continue;
218+
}
219+
// TODO(TJ): reuse plw-kw outside this for
220+
// try to unify
221+
for (int kw = 0; kw < plw; ++kw) {
222+
std::memcpy(dst_data + (plw - kw), src_data,
223+
sizeof(T) * (output_width - (plw - kw)));
224+
dst_data = dst_data + col_matrix_width;
225+
}
226+
for (int kw = plw; kw < filter_width - prw; ++kw) {
227+
std::memcpy(dst_data, src_data + (kw - plw),
228+
sizeof(T) * output_width);
229+
dst_data = dst_data + col_matrix_width;
230+
}
231+
int i = 1;
232+
for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
233+
std::memcpy(dst_data, src_data + (kw - plw),
234+
sizeof(T) * (output_width - i));
235+
dst_data = dst_data + col_matrix_width;
236+
}
237+
src_data = src_data + im_width;
238+
}
239+
}
260240
}
261241
}
262242

paddle/fluid/operators/math/im2col_test.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
227227
auto t3 = GetCurrentMs();
228228

229229
LOG(INFO) << "before: " << (t3 - t2) / repeat
230-
<< ",after: " << (t2 - t1) / repeat;
230+
<< ",after: " << (t2 - t1) / repeat
231+
<< ",boost: " << ((t3 - t2) / (t2 - t1) - 1) * 100 << "%";
231232
}
232233

233234
TEST(math, im2col_cputest) {
@@ -244,20 +245,25 @@ TEST(math, im2col_cputest) {
244245
// height != width
245246
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 2, /*fw*/ 3, /*ph*/ p,
246247
/*pw*/ p);
248+
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 1, /*fw*/ 3, /*ph*/ p,
249+
/*pw*/ p);
250+
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 5, /*fh*/ 3, /*fw*/ 1, /*ph*/ p,
251+
/*pw*/ p);
247252

248253
// filter == 1
249254
testIm2colCPU(/*ic*/ 3, /*ih*/ 4, /*iw*/ 4, /*fh*/ 1, /*fw*/ 1, /*ph*/ p,
250255
/*pw*/ p);
251256
testIm2colCPU(/*ic*/ 3, /*ih*/ 3, /*iw*/ 4, /*fh*/ 1, /*fw*/ 1, /*ph*/ p,
252257
/*pw*/ p);
253258
}
259+
254260
// padding_h != padding_w
255261
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 4, /*fh*/ 2, /*fw*/ 3, /*ph*/ 1,
256262
/*pw*/ 2);
257263

258264
// benchmark
259-
for (int p : {0, 1, 2}) {
260-
for (int k : {3, 5}) {
265+
for (int p : {0, 1}) {
266+
for (int k : {1, 3, 5}) {
261267
LOG(INFO) << "padding == " << p << ", filter == " << k;
262268
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ k, /*fw*/ k,
263269
/*ph*/ p, /*pw*/ p);

0 commit comments

Comments
 (0)