@@ -79,25 +79,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
79
79
// plw;
80
80
81
81
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82
- // TODO(TJ): reuse sizes
82
+ // TODO(TJ): refine ph*xxx
83
83
assert (plh == prh); // because stride_h == 1
84
+ int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow
85
+ int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow
84
86
for (int ph = 0 ; ph < plh; ++ph) {
85
- size_t sz = sizeof (T) * output_width * (plh - ph);
86
- T* col_start_l = col_data + ph * filter_width * col_matrix_width;
87
- T* col_start_r =
88
- col_data +
89
- (filter_width - ph - 1 ) * filter_width * col_matrix_width +
90
- col_matrix_width - output_width * (plh - ph);
87
+ int sz = output_width * (plh - ph);
88
+ size_t copy_sz = sizeof (T) * sz;
89
+ T* col_start_l = col_data + ph * col_block_fh;
90
+ T* col_start_r = col_data + (filter_height - ph - 1 ) * col_block_fh +
91
+ col_matrix_width - sz;
91
92
for (int ic = 0 ; ic < im_channels; ++ic) {
92
- T* dst_data_l =
93
- col_start_l +
94
- ic * filter_width * filter_height * col_matrix_width;
95
- T* dst_data_r =
96
- col_start_r +
97
- ic * filter_width * filter_height * col_matrix_width;
93
+ T* dst_data_l = col_start_l + ic * col_block_ic;
94
+ T* dst_data_r = col_start_r + ic * col_block_ic;
98
95
for (int kw = 0 ; kw < filter_width; ++kw) {
99
- std::memset (dst_data_l, 0 , sz );
100
- std::memset (dst_data_r, 0 , sz );
96
+ std::memset (dst_data_l, 0 , copy_sz );
97
+ std::memset (dst_data_r, 0 , copy_sz );
101
98
dst_data_l = dst_data_l + col_matrix_width;
102
99
dst_data_r = dst_data_r + col_matrix_width;
103
100
}
0 commit comments