
Commit 179dd0c

Merge pull request #12337 from tensor-tang/refine/im2col
refine cpu im2col no padding
2 parents (2d21aa7 + b72befc) · commit 179dd0c

File tree

2 files changed: +102 −5 lines


paddle/fluid/operators/math/im2col.cc

Lines changed: 30 additions & 5 deletions
@@ -40,22 +40,47 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
     int im_width = im.dims()[2];
     int filter_height = col->dims()[1];
     int filter_width = col->dims()[2];
-    int col_height = col->dims()[3];
-    int col_width = col->dims()[4];
+    int output_height = col->dims()[3];
+    int output_width = col->dims()[4];
 
     int channels_col = im_channels * filter_height * filter_width;
 
     const T* im_data = im.data<T>();
     T* col_data = col->data<T>();
+    // TODO(TJ): change me to template
+    // further optimaze:
+    // 1. padding != 1
+    // 2. could also support stride_h != 1
+    if (stride[0] == 1 && stride[1] == 1 && dilation[0] == 1 &&
+        dilation[1] == 1 && padding[0] == 0 && padding[1] == 0) {
+      int col_matrix_width = output_width * output_height;
+      size_t copy_size = sizeof(T) * output_width;
+      for (int oh = 0; oh < output_height; ++oh) {
+        const T* im_data_start = im_data + oh * im_width;
+        T* dst_data = col_data + oh * output_width;
+        for (int ic = 0; ic < im_channels; ++ic) {
+          const T* src_data = im_data_start + ic * im_height * im_width;
+          for (int kh = 0; kh < filter_height; ++kh) {
+            for (int kw = 0; kw < filter_width; ++kw) {
+              std::memcpy(dst_data, src_data + kw, copy_size);
+              dst_data = dst_data + col_matrix_width;
+            }
+            src_data = src_data + im_width;
+          }
+        }
+      }
+      return;
+    }
+
     for (int c = 0; c < channels_col; ++c) {
       int w_offset = c % filter_width;
       int h_offset = (c / filter_width) % filter_height;
       int c_im = c / (filter_width * filter_height);
-      for (int h = 0; h < col_height; ++h) {
+      for (int h = 0; h < output_height; ++h) {
         int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
-        for (int w = 0; w < col_width; ++w) {
+        for (int w = 0; w < output_width; ++w) {
           int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
-          int col_idx = (c * col_height + h) * col_width + w;
+          int col_idx = (c * output_height + h) * output_width + w;
           int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
 
           col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
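
The fast path added above relies on a simple observation: with stride 1, dilation 1, and zero padding, the output_width values written for a fixed (channel, kh, kw, oh) come from a contiguous run of the corresponding input row, so one memcpy of output_width elements replaces the innermost element-wise loop. The standalone sketch below restates that idea on raw float arrays; the function name Im2ColNoPadding and the raw-pointer interface are illustrative only, not part of the Paddle sources.

#include <cstring>

// Sketch of the stride-1, dilation-1, zero-padding im2col fast path.
// Input layout:  [channels, height, width]
// Col layout:    [channels, filter_h, filter_w, out_h, out_w]  (kCFO)
void Im2ColNoPadding(const float* im, int channels, int height, int width,
                     int filter_h, int filter_w, float* col) {
  int out_h = height - filter_h + 1;
  int out_w = width - filter_w + 1;
  int col_matrix_width = out_h * out_w;  // size of one (c, kh, kw) plane
  size_t copy_size = sizeof(float) * out_w;
  for (int oh = 0; oh < out_h; ++oh) {
    const float* im_row = im + oh * width;  // first input row feeding output row oh
    float* dst = col + oh * out_w;          // row oh of the first (c, kh, kw) plane
    for (int c = 0; c < channels; ++c) {
      const float* src = im_row + c * height * width;
      for (int kh = 0; kh < filter_h; ++kh) {
        for (int kw = 0; kw < filter_w; ++kw) {
          // The out_w values for (c, kh, kw, oh, *) are contiguous in the
          // input row, so a single memcpy covers all of them.
          std::memcpy(dst, src + kw, copy_size);
          dst += col_matrix_width;  // jump to the next (c, kh, kw) plane
        }
        src += width;  // next input row inside the filter window
      }
    }
  }
}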

paddle/fluid/operators/math/im2col_test.cc

Lines changed: 72 additions & 0 deletions
@@ -160,8 +160,80 @@ void testIm2col() {
   delete context;
 }
 
+void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
+  paddle::framework::Tensor input;
+  paddle::framework::Tensor output;
+  paddle::framework::Tensor ref_output;
+  std::vector<int> padding({ph, pw});
+  std::vector<int> stride({1, 1});    // stride_y, stride_x
+  std::vector<int> dilation({1, 1});  // dilation_y, dilation_x
+  int output_height = (ih - fh + padding[0] * 2) / stride[0] + 1;
+  int output_width = (iw - fw + padding[1] * 2) / stride[1] + 1;
+  float* input_ptr =
+      input.mutable_data<float>({ic, ih, iw}, paddle::platform::CPUPlace());
+  for (int i = 0; i < input.numel(); ++i) {
+    input_ptr[i] = static_cast<float>(i + 1);
+  }
+
+  paddle::platform::CPUPlace place;
+  paddle::platform::CPUDeviceContext context(place);
+  output.mutable_data<float>({ic, fh, fw, output_height, output_width}, place);
+  ref_output.mutable_data<float>({ic, fh, fw, output_height, output_width},
+                                 place);
+  paddle::operators::math::Im2ColFunctor<
+      paddle::operators::math::ColFormat::kCFO,
+      paddle::platform::CPUDeviceContext, float>
+      im2col;
+  im2col(context, input, dilation, stride, padding, &output);
+  auto ref_im2col = [&](
+      const paddle::framework::Tensor& im, const std::vector<int>& dilation,
+      const std::vector<int>& stride, const std::vector<int>& padding,
+      paddle::framework::Tensor* col) {
+    int im_channels = im.dims()[0];
+    int im_height = im.dims()[1];
+    int im_width = im.dims()[2];
+    int filter_height = col->dims()[1];
+    int filter_width = col->dims()[2];
+    int output_height = col->dims()[3];
+    int output_width = col->dims()[4];
+    int channels_col = im_channels * filter_height * filter_width;
+
+    const float* im_data = im.data<float>();
+    float* col_data = col->data<float>();
+    for (int c = 0; c < channels_col; ++c) {
+      int w_offset = c % filter_width;
+      int h_offset = (c / filter_width) % filter_height;
+      int c_im = c / (filter_width * filter_height);
+      for (int h = 0; h < output_height; ++h) {
+        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
+        for (int w = 0; w < output_width; ++w) {
+          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
+          int col_idx = (c * output_height + h) * output_width + w;
+          int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
+          col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
+                               im_col_idx < 0 || im_col_idx >= im_width)
+                                  ? 0.f
+                                  : im_data[im_idx];
+        }
+      }
+    }
+  };
+
+  ref_im2col(input, dilation, stride, padding, &ref_output);
+
+  float* out_cfo_ptr = output.data<float>();
+  float* out_ref_ptr = ref_output.data<float>();
+  for (int i = 0; i < output.numel(); ++i) {
+    EXPECT_EQ(out_cfo_ptr[i], out_ref_ptr[i]);
+  }
+}
+
 TEST(math, im2col) {
   testIm2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
+  testIm2colCPU(/*ic*/ 3, /*ih*/ 5, /*iw*/ 5, /*fh*/ 3, /*fw*/ 2, /*ph*/ 0,
+                /*pw*/ 0);
+  testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 3, /*fw*/ 3, /*ph*/ 1,
+                /*pw*/ 1);
 #ifdef PADDLE_WITH_CUDA
   testIm2col<paddle::platform::CUDADeviceContext,
              paddle::platform::CUDAPlace>();
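
The two new test cases exercise both code paths: the first call (ph = pw = 0 with stride 1 and dilation 1) hits the new memcpy fast path, while the second (ph = pw = 1) falls through to the original element-wise loop. Below is a small illustrative check of the expected col dimensions, using the same output-size formula as the test; it is not part of the commit.

#include <cassert>

// Same formula as in testIm2colCPU: out = (in - filter + 2 * pad) / stride + 1.
static int OutSize(int in, int filter, int pad, int stride) {
  return (in - filter + 2 * pad) / stride + 1;
}

int main() {
  // Case 1: ic=3, ih=5, iw=5, fh=3, fw=2, ph=pw=0 -> col dims {3, 3, 2, 3, 4}.
  assert(OutSize(5, 3, 0, 1) == 3);  // output_height
  assert(OutSize(5, 2, 0, 1) == 4);  // output_width
  // Case 2: ic=2, ih=5, iw=4, fh=3, fw=3, ph=pw=1 -> col dims {2, 3, 3, 5, 4}.
  assert(OutSize(5, 3, 1, 1) == 5);  // output_height
  assert(OutSize(4, 3, 1, 1) == 4);  // output_width
  return 0;
}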
