@@ -48,29 +48,63 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
48
48
const T* im_data = im.data <T>();
49
49
T* col_data = col->data <T>();
50
50
// TODO(TJ): change me to template
51
- // further optimaze:
52
- // 1. padding != 1
53
- // 2. could also support stride_h != 1
51
+ // further optimize: padding == 1 needs special handling
54
52
if (stride[0 ] == 1 && stride[1 ] == 1 && dilation[0 ] == 1 &&
55
- dilation[1 ] == 1 && padding[ 0 ] == 0 && padding[ 1 ] == 0 ) {
53
+ dilation[1 ] == 1 ) {
56
54
int col_matrix_width = output_width * output_height;
57
55
int im_size = im_height * im_width;
58
- size_t copy_size = sizeof (T) * output_width;
59
- for (int oh = 0 ; oh < output_height; ++oh) {
60
- const T* im_data_start = im_data + oh * im_width;
61
- T* dst_data = col_data + oh * output_width;
62
- for (int ic = 0 ; ic < im_channels; ++ic) {
63
- const T* src_data = im_data_start + ic * im_size;
64
- for (int kh = 0 ; kh < filter_height; ++kh) {
56
+ if (padding[0 ] == 0 && padding[1 ] == 0 ) {
57
+ size_t copy_size = sizeof (T) * output_width;
58
+ for (int oh = 0 ; oh < output_height; ++oh) {
59
+ const T* im_data_start = im_data + oh * im_width;
60
+ T* dst_data = col_data + oh * output_width;
61
+ for (int ic = 0 ; ic < im_channels; ++ic) {
62
+ const T* src_data = im_data_start + ic * im_size;
63
+ for (int kh = 0 ; kh < filter_height; ++kh) {
64
+ for (int kw = 0 ; kw < filter_width; ++kw) {
65
+ std::memcpy (dst_data, src_data + kw, copy_size);
66
+ dst_data = dst_data + col_matrix_width;
67
+ }
68
+ src_data = src_data + im_width;
69
+ }
70
+ }
71
+ }
72
+ return ;
73
+ } else {
74
+ int plh = padding[0 ];
75
+ // int plw = padding[1];
76
+ int prh =
77
+ (output_height - 1 ) * stride[0 ] + filter_height - im_height - plh;
78
+ // int prw = (output_width - 1) * stride[1] + filter_width - im_width -
79
+ // plw;
80
+
81
+ // fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82
+ // TODO(TJ): reuse sizes
83
+ assert (plh == prh); // because stride_h == 1
84
+ for (int ph = 0 ; ph < plh; ++ph) {
85
+ size_t sz = sizeof (T) * output_width * (plh - ph);
86
+ T* col_start_l = col_data + ph * filter_width * col_matrix_width;
87
+ T* col_start_r =
88
+ col_data +
89
+ (filter_width - ph - 1 ) * filter_width * col_matrix_width +
90
+ col_matrix_width - output_width * (plh - ph);
91
+ for (int ic = 0 ; ic < im_channels; ++ic) {
92
+ T* dst_data_l =
93
+ col_start_l +
94
+ ic * filter_width * filter_height * col_matrix_width;
95
+ T* dst_data_r =
96
+ col_start_r +
97
+ ic * filter_width * filter_height * col_matrix_width;
65
98
for (int kw = 0 ; kw < filter_width; ++kw) {
66
- std::memcpy (dst_data, src_data + kw, copy_size);
67
- dst_data = dst_data + col_matrix_width;
99
+ std::memset (dst_data_l, 0 , sz);
100
+ std::memset (dst_data_r, 0 , sz);
101
+ dst_data_l = dst_data_l + col_matrix_width;
102
+ dst_data_r = dst_data_r + col_matrix_width;
68
103
}
69
- src_data = src_data + im_width;
70
104
}
71
105
}
106
+ return ;
72
107
}
73
- return ;
74
108
}
75
109
76
110
for (int c = 0 ; c < channels_col; ++c) {
0 commit comments