Skip to content

Commit 059b278

Browse files
authored
Merge pull request #12408 from tensor-tang/refine/im2col
Refine CPU im2col padding with 1
2 parents b0cf1fe + d8d2dbc commit 059b278

File tree

3 files changed

+365
-124
lines changed

3 files changed

+365
-124
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 10 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/im2col.h"
1616
#include <vector>
17+
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
1718

1819
namespace paddle {
1920
namespace operators {
@@ -35,61 +36,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
3536
PADDLE_ENFORCE(im.dims().size() == 3);
3637
PADDLE_ENFORCE(col->dims().size() == 5);
3738

38-
int im_channels = im.dims()[0];
39-
int im_height = im.dims()[1];
40-
int im_width = im.dims()[2];
41-
int filter_height = col->dims()[1];
42-
int filter_width = col->dims()[2];
43-
int output_height = col->dims()[3];
44-
int output_width = col->dims()[4];
45-
46-
int channels_col = im_channels * filter_height * filter_width;
47-
48-
const T* im_data = im.data<T>();
49-
T* col_data = col->data<T>();
50-
// TODO(TJ): change me to template
51-
// further optimaze:
52-
// 1. padding != 1
53-
// 2. could also support stride_h != 1
5439
if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
55-
dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) {
56-
int col_matrix_width = output_width * output_height;
57-
size_t copy_size = sizeof(T) * output_width;
58-
for (int oh = 0; oh < output_height; ++oh) {
59-
const T* im_data_start = im_data + oh * im_width;
60-
T* dst_data = col_data + oh * output_width;
61-
for (int ic = 0; ic < im_channels; ++ic) {
62-
const T* src_data = im_data_start + ic * im_height * im_width;
63-
for (int kh = 0; kh < filter_height; ++kh) {
64-
for (int kw = 0; kw < filter_width; ++kw) {
65-
std::memcpy(dst_data, src_data + kw, copy_size);
66-
dst_data = dst_data + col_matrix_width;
67-
}
68-
src_data = src_data + im_width;
69-
}
70-
}
71-
}
72-
return;
73-
}
74-
75-
for (int c = 0; c < channels_col; ++c) {
76-
int w_offset = c % filter_width;
77-
int h_offset = (c / filter_width) % filter_height;
78-
int c_im = c / (filter_width * filter_height);
79-
for (int h = 0; h < output_height; ++h) {
80-
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
81-
for (int w = 0; w < output_width; ++w) {
82-
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
83-
int col_idx = (c * output_height + h) * output_width + w;
84-
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
85-
86-
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
87-
im_col_idx < 0 || im_col_idx >= im_width)
88-
? static_cast<T>(0)
89-
: im_data[im_idx];
90-
}
40+
dilation[1] == 1) {
41+
if (padding[0] == 0 && padding[1] == 0) {
42+
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
43+
return;
44+
} else if (padding[0] == 1 && padding[1] == 1) {
45+
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
46+
return;
9147
}
48+
// TODO(TJ): complete padding >=2
9249
}
50+
im2col_common<T>(im, dilation, stride, padding, col);
9351
}
9452
};
9553

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once

#include <cstring>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
namespace math {
23+
24+
/**
25+
* The most common im2col algorithm.
26+
* Support dilation, stride and padding.
27+
*/
28+
template <typename T>
29+
inline void im2col_common(const framework::Tensor& im,
30+
const std::vector<int>& dilation,
31+
const std::vector<int>& stride,
32+
const std::vector<int>& padding,
33+
framework::Tensor* col) {
34+
int im_channels = im.dims()[0];
35+
int im_height = im.dims()[1];
36+
int im_width = im.dims()[2];
37+
int filter_height = col->dims()[1];
38+
int filter_width = col->dims()[2];
39+
int output_height = col->dims()[3];
40+
int output_width = col->dims()[4];
41+
int channels_col = im_channels * filter_height * filter_width;
42+
43+
const T* im_data = im.data<T>();
44+
T* col_data = col->data<T>();
45+
for (int c = 0; c < channels_col; ++c) {
46+
int w_offset = c % filter_width;
47+
int h_offset = (c / filter_width) % filter_height;
48+
int c_im = c / (filter_width * filter_height);
49+
for (int h = 0; h < output_height; ++h) {
50+
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
51+
for (int w = 0; w < output_width; ++w) {
52+
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
53+
int col_idx = (c * output_height + h) * output_width + w;
54+
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
55+
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
56+
im_col_idx < 0 || im_col_idx >= im_width)
57+
? static_cast<T>(0)
58+
: im_data[im_idx];
59+
}
60+
}
61+
}
62+
}
63+
64+
/**
65+
* im2col algorithm with strides == 1, dilations == 1, paddings == 0
66+
*/
67+
template <typename T>
68+
inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
69+
framework::Tensor* col) {
70+
int im_channels = im.dims()[0];
71+
int im_height = im.dims()[1];
72+
int im_width = im.dims()[2];
73+
int filter_height = col->dims()[1];
74+
int filter_width = col->dims()[2];
75+
int output_height = col->dims()[3];
76+
int output_width = col->dims()[4];
77+
78+
const T* im_data = im.data<T>();
79+
T* col_data = col->data<T>();
80+
int col_matrix_width = output_width * output_height;
81+
int im_size = im_height * im_width;
82+
size_t copy_size = sizeof(T) * output_width;
83+
const T* im_data_oh = im_data;
84+
T* dst_data_oh = col_data;
85+
for (int oh = 0; oh < output_height; ++oh) {
86+
const T* src_data_ic = im_data_oh;
87+
T* dst_data = dst_data_oh;
88+
for (int ic = 0; ic < im_channels; ++ic) {
89+
const T* src_data = src_data_ic;
90+
for (int kh = 0; kh < filter_height; ++kh) {
91+
for (int kw = 0; kw < filter_width; ++kw) {
92+
std::memcpy(dst_data, src_data + kw, copy_size);
93+
dst_data = dst_data + col_matrix_width;
94+
}
95+
src_data = src_data + im_width;
96+
}
97+
src_data_ic = src_data_ic + im_size;
98+
}
99+
im_data_oh = im_data_oh + im_width;
100+
dst_data_oh = dst_data_oh + output_width;
101+
}
102+
}
103+
104+
/**
 * Fast-path im2col for strides == 1, dilations == 1, paddings == 1.
 * Padded border positions are filled first (height padding, then width
 * padding), after which the interior is bulk-copied with memcpy.
 * filter_width == 1 takes a dedicated branch because it has no interior
 * columns to copy around the width padding.
 *
 * \param im   input image tensor, dims {channels, height, width}
 * \param col  output tensor, dims {channels, fh, fw, out_h, out_w}
 */
template <typename T>
inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
                                      framework::Tensor* col) {
  int im_channels = im.dims()[0];
  int im_height = im.dims()[1];
  int im_width = im.dims()[2];
  int filter_height = col->dims()[1];
  int filter_width = col->dims()[2];
  int output_height = col->dims()[3];
  int output_width = col->dims()[4];

  // Padding is fixed at 1 on every side for this specialization:
  // pad-left/right height (plh/prh) and pad-left/right width (plw/prw).
  constexpr int plh = 1;
  constexpr int prh = 1;
  constexpr int plw = 1;
  constexpr int prw = 1;

  const T* im_data = im.data<T>();
  T* col_data = col->data<T>();
  int im_size = im_height * im_width;
  int col_matrix_width = output_width * output_height;
  int col_block_fh = filter_width * col_matrix_width;  // fw*oh*ow
  int col_block_ic = filter_height * col_block_fh;     // fh*fw*oh*ow

  // fill height padding:
  // the first output row for kh == 0 and the last output row for
  // kh == filter_height-1 read entirely from the padded border, so those
  // output_width-long stripes are zeroed for every (ic, kw).
  {
    size_t copy_size = sizeof(T) * output_width;
    T* col_start_l = col_data;
    // Start of the LAST output row within the last kernel-row block.
    T* col_start_r = col_data + (filter_height - 1) * col_block_fh +
                     col_matrix_width - output_width;
    for (int ic = 0; ic < im_channels; ++ic) {
      T* dst_data_l = col_start_l;
      T* dst_data_r = col_start_r;
      for (int kw = 0; kw < filter_width; ++kw) {
        std::memset(dst_data_l, 0, copy_size);
        std::memset(dst_data_r, 0, copy_size);
        dst_data_l = dst_data_l + col_matrix_width;
        dst_data_r = dst_data_r + col_matrix_width;
      }
      col_start_l = col_start_l + col_block_ic;
      col_start_r = col_start_r + col_block_ic;
    }
  }

  auto pad = static_cast<T>(0);
  if (filter_width == 1) {
    // fill width padding:
    // with a single kernel column, the first and last element of every
    // output row come from the left/right padded columns.
    T* dst_data_ic = col_data;
    for (int ic = 0; ic < im_channels; ++ic) {
      T* dst_data_kh = dst_data_ic;
      for (int kh = 0; kh < filter_height; ++kh) {
        T* dst_data = dst_data_kh;
        for (int oh = 0; oh < output_height; ++oh) {
          *dst_data = pad;                         // leftmost element
          dst_data = dst_data + output_width - 1;  // jump to row end
          *dst_data = pad;                         // rightmost element
          ++dst_data;                              // first of next row
        }
        dst_data_kh = dst_data_kh + col_block_fh;
      }
      dst_data_ic = dst_data_ic + col_block_ic;
    }
    // fill core: copy the interior (output_width - 2) pixels per row.
    size_t copy_size = sizeof(T) * (output_width - plw - prw);
    for (int oh = 0; oh < output_height; ++oh) {
      const T* im_data_start =
          im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
      T* dst_data = col_data + oh * output_width;
      for (int ic = 0; ic < im_channels; ++ic) {
        const T* src_data = im_data_start + ic * im_size;
        for (int kh = 0; kh < filter_height; ++kh) {
          // Skip rows already zeroed by the height-padding pass above.
          if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) &&
                                         kh > (filter_height - prh - 1))) {
            dst_data = dst_data + col_matrix_width;
            continue;
          }
          std::memcpy(dst_data + plw, src_data, copy_size);
          dst_data = dst_data + col_matrix_width;
          src_data = src_data + im_width;
        }
      }
    }
    return;
  }

  // filter_width != 1
  // fill width padding:
  // for kw == 0 the first element of each output row is padding; for
  // kw == filter_width-1 the last element is. Both stripes are zeroed here.
  T* dst_data_ic = col_data;
  for (int ic = 0; ic < im_channels; ++ic) {
    T* dst_data_kh = dst_data_ic;
    for (int kh = 0; kh < filter_height; ++kh) {
      // Two start pointers: column 0 of the kw==0 block, and the last
      // column of the kw==filter_width-1 block.
      for (T* dst_data :
           {dst_data_kh, dst_data_kh + (filter_width - prw) * col_matrix_width +
                             output_width - 1}) {
        // TODO(TJ): from plh, saving repeated assignment
        for (int oh = 0; oh < output_height; ++oh) {
          *dst_data = pad;
          dst_data = dst_data + output_width;
        }
      }
      dst_data_kh = dst_data_kh + col_block_fh;
    }
    dst_data_ic = dst_data_ic + col_block_ic;
  }

  // Core copy. Each row splits into three kw ranges: left-shifted copies
  // (kw < plw), full-width copies, and right-truncated copies (kw near the
  // right edge).
  // TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
  // (output_width-1)}
  // length of copy_size is equal kw.
  for (int oh = 0; oh < output_height; ++oh) {
    const T* im_data_start = im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
    T* dst_data = col_data + oh * output_width;
    for (int ic = 0; ic < im_channels; ++ic) {
      const T* src_data = im_data_start + ic * im_size;
      for (int kh = 0; kh < filter_height; ++kh) {
        // Skip whole kernel rows already zeroed by the height-padding pass.
        if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) &&
                                       kh > (filter_height - prh - 1))) {
          dst_data = dst_data + filter_width * col_matrix_width;
          continue;
        }
        // TODO(TJ): reuse plw-kw outside this for
        // try to unify
        // Left range: destination offset by (plw - kw), shortened copy.
        for (int kw = 0; kw < plw; ++kw) {
          std::memcpy(dst_data + (plw - kw), src_data,
                      sizeof(T) * (output_width - (plw - kw)));
          dst_data = dst_data + col_matrix_width;
        }
        // Middle range: full-width copy, source shifted by (kw - plw).
        for (int kw = plw; kw < filter_width - prw; ++kw) {
          std::memcpy(dst_data, src_data + (kw - plw),
                      sizeof(T) * output_width);
          dst_data = dst_data + col_matrix_width;
        }
        // Right range: copy truncated by i so the zeroed right border
        // written above is preserved.
        int i = 1;
        for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
          std::memcpy(dst_data, src_data + (kw - plw),
                      sizeof(T) * (output_width - i));
          dst_data = dst_data + col_matrix_width;
        }
        src_data = src_data + im_width;
      }
    }
  }
}
249+
250+
} // namespace math
251+
} // namespace operators
252+
} // namespace paddle

0 commit comments

Comments
 (0)