@@ -40,22 +40,46 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
40
40
int im_width = im.dims ()[2 ];
41
41
int filter_height = col->dims ()[1 ];
42
42
int filter_width = col->dims ()[2 ];
43
- int col_height = col->dims ()[3 ];
44
- int col_width = col->dims ()[4 ];
43
+ int output_height = col->dims ()[3 ];
44
+ int output_width = col->dims ()[4 ];
45
45
46
46
int channels_col = im_channels * filter_height * filter_width;
47
47
48
48
const T* im_data = im.data <T>();
49
49
T* col_data = col->data <T>();
50
+ // TODO(TJ): change me to template
51
+ // further optimize:
52
+ // 1. padding != 1
53
+ // 2. could also support stride_h != 1
54
+ if (stride[0 ] == 1 && stride[1 ] == 1 && dilation[0 ] == 1 &&
55
+ dilation[1 ] == 1 && padding[0 ] == 0 && padding[1 ] == 0 ) {
56
+ int col_matrix_width = output_width * output_height;
57
+ for (int oh = 0 ; oh < output_height; ++oh) {
58
+ const T* im_data_start = im_data + oh * im_width;
59
+ T* dst_data = col_data + oh * output_width;
60
+ for (int ic = 0 ; ic < im_channels; ++ic) {
61
+ const T* src_data = im_data_start + ic * im_height * im_width;
62
+ for (int kh = 0 ; kh < filter_height; ++kh) {
63
+ for (int kw = 0 ; kw < filter_width; ++kw) {
64
+ std::memcpy (dst_data, src_data + kw, sizeof (T) * output_width);
65
+ dst_data = dst_data + col_matrix_width;
66
+ }
67
+ src_data = src_data + im_width;
68
+ }
69
+ }
70
+ }
71
+ return ;
72
+ }
73
+
50
74
for (int c = 0 ; c < channels_col; ++c) {
51
75
int w_offset = c % filter_width;
52
76
int h_offset = (c / filter_width) % filter_height;
53
77
int c_im = c / (filter_width * filter_height);
54
- for (int h = 0 ; h < col_height ; ++h) {
78
+ for (int h = 0 ; h < output_height ; ++h) {
55
79
int im_row_idx = h * stride[0 ] - padding[0 ] + h_offset * dilation[0 ];
56
- for (int w = 0 ; w < col_width ; ++w) {
80
+ for (int w = 0 ; w < output_width ; ++w) {
57
81
int im_col_idx = w * stride[1 ] - padding[1 ] + w_offset * dilation[1 ];
58
- int col_idx = (c * col_height + h) * col_width + w;
82
+ int col_idx = (c * output_height + h) * output_width + w;
59
83
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
60
84
61
85
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
0 commit comments