Skip to content

Commit 92518c5

Browse files
committed
reuse sizes saving time
1 parent 660df12 commit 92518c5

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

paddle/fluid/operators/math/im2col.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
7979
// plw;
8080

8181
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82-
// TODO(TJ): reuse sizes
82+
// TODO(TJ): refine ph*xxx
8383
assert(plh == prh); // because stride_h == 1
84+
int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow
85+
int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow
8486
for (int ph = 0; ph < plh; ++ph) {
85-
size_t sz = sizeof(T) * output_width * (plh - ph);
86-
T* col_start_l = col_data + ph * filter_width * col_matrix_width;
87-
T* col_start_r =
88-
col_data +
89-
(filter_width - ph - 1) * filter_width * col_matrix_width +
90-
col_matrix_width - output_width * (plh - ph);
87+
int sz = output_width * (plh - ph);
88+
size_t copy_sz = sizeof(T) * sz;
89+
T* col_start_l = col_data + ph * col_block_fh;
90+
T* col_start_r = col_data + (filter_height - ph - 1) * col_block_fh +
91+
col_matrix_width - sz;
9192
for (int ic = 0; ic < im_channels; ++ic) {
92-
T* dst_data_l =
93-
col_start_l +
94-
ic * filter_width * filter_height * col_matrix_width;
95-
T* dst_data_r =
96-
col_start_r +
97-
ic * filter_width * filter_height * col_matrix_width;
93+
T* dst_data_l = col_start_l + ic * col_block_ic;
94+
T* dst_data_r = col_start_r + ic * col_block_ic;
9895
for (int kw = 0; kw < filter_width; ++kw) {
99-
std::memset(dst_data_l, 0, sz);
100-
std::memset(dst_data_r, 0, sz);
96+
std::memset(dst_data_l, 0, copy_sz);
97+
std::memset(dst_data_r, 0, copy_sz);
10198
dst_data_l = dst_data_l + col_matrix_width;
10299
dst_data_r = dst_data_r + col_matrix_width;
103100
}

0 commit comments

Comments
 (0)