@@ -72,11 +72,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
72
72
return ;
73
73
} else {
74
74
int plh = padding[0 ];
75
- // int plw = padding[1];
75
+ int plw = padding[1 ];
76
76
int prh =
77
77
(output_height - 1 ) * stride[0 ] + filter_height - im_height - plh;
78
- // int prw = (output_width - 1) * stride[1] + filter_width - im_width -
79
- // plw;
78
+ int prw =
79
+ (output_width - 1 ) * stride[ 1 ] + filter_width - im_width - plw;
80
80
81
81
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
82
82
// TODO(TJ): refine ph*xxx
@@ -100,6 +100,64 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
100
100
}
101
101
}
102
102
}
103
+
104
+ // fill width padding
105
+ assert (plw == prw); // because stride_w == 1
106
+ if (plw == 1 ) {
107
+ auto pad = static_cast <T>(0 ); // padding zero
108
+ for (int ic = 0 ; ic < im_channels; ++ic) {
109
+ // TODO(TJ): use add and reuse stride
110
+ T* dst_data_ic = col_data + ic * col_block_ic;
111
+ for (int kh = 0 ; kh < filter_height; ++kh) {
112
+ T* dst_data_kh = dst_data_ic + kh * col_block_fh;
113
+ for (T* dst_data :
114
+ {dst_data_kh, dst_data_kh +
115
+ (filter_width - prw) * col_matrix_width +
116
+ output_width - 1 }) {
117
+ // TODO(TJ): from plh, saving repeated assignment
118
+ for (int oh = 0 ; oh < output_height; ++oh) {
119
+ *dst_data = pad;
120
+ dst_data = dst_data + output_width;
121
+ }
122
+ }
123
+ }
124
+ }
125
+ } else {
126
+ // padding_size > 1
127
+ for (int ic = 0 ; ic < im_channels; ++ic) {
128
+ // TODO(TJ): use add and reuse stride
129
+ T* dst_data_ic =
130
+ col_data + ic * filter_width * filter_height * col_matrix_width;
131
+ for (int kh = 0 ; kh < filter_height; ++kh) {
132
+ T* dst_data_kh =
133
+ dst_data_ic + kh * filter_width * col_matrix_width;
134
+ for (int kw = 0 ; kw < plw; ++kw) {
135
+ // TODO(TJ): reuse array outside this for
136
+ size_t sz = sizeof (T) * (plw - kw);
137
+ T* dst_data = dst_data_kh + kw * col_matrix_width;
138
+ // TODO(TJ): from plh, saving repeated assignment
139
+ for (int oh = 0 ; oh < output_height; ++oh) {
140
+ std::memset (dst_data, 0 , sz);
141
+ dst_data = dst_data + output_width;
142
+ }
143
+ }
144
+ // TODO(TJ): use reverse to save cache
145
+ for (int kw = 0 ; kw < prw; ++kw) {
146
+ // TODO(TJ): reuse array outside this for
147
+ auto num = (prw - kw);
148
+ size_t sz = sizeof (T) * num;
149
+ T* dst_data = dst_data_kh +
150
+ (filter_width - 1 - kw) * col_matrix_width +
151
+ output_width - num;
152
+ // TODO(TJ): from plh, saving repeated assignment
153
+ for (int oh = 0 ; oh < output_height; ++oh) {
154
+ std::memset (dst_data, 0 , sz);
155
+ dst_data = dst_data + output_width;
156
+ }
157
+ }
158
+ }
159
+ }
160
+ }
103
161
return ;
104
162
}
105
163
}
0 commit comments