@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                      std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
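As context for the indexing change above: the expansion (im2col/vol2col) can be skipped only when every spatial filter extent is 1, every stride is 1, every padding is 0, and every dilation is 1. Since filter_dim now carries the full {k_o, k_i, ...} layout rather than just the spatial extents, the loop reads filter_dim[j + 2]. A minimal standalone sketch (the IsExpandSketch name and harness are hypothetical stand-ins for the kernel's IsExpand):

#include <cstdint>
#include <vector>

// Sketch of the predicate after this patch: returns true when the
// im2col/vol2col expansion is actually needed. filter_dim holds the full
// filter shape {k_o, k_i, k_h, k_w} (or the 3-D variant), so the spatial
// extents start at offset 2.
bool IsExpandSketch(const std::vector<int64_t>& filter_dim,
                    const std::vector<int>& strides,
                    const std::vector<int>& paddings,
                    const std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  // Only a 1x1, stride-1, padding-0, dilation-1 convolution avoids expansion.
  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}

// e.g. IsExpandSketch({64, 3, 1, 1}, {1, 1}, {0, 0}, {1, 1}) == false,
// while IsExpandSketch({64, 3, 3, 3}, {1, 1}, {1, 1}, {1, 1}) == true.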
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
         col.ShareDataWith(in_slice);
         col_matrix.ShareDataWith(col);
         col_matrix.Resize(col_matrix_shape);
-      } else if (filter_shape_vec.size() == 2) {
+      } else if (data_dim == 2U) {
         // im2col
         im2col(context.device_context(), in_slice, dilations, strides,
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &col);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (data_dim == 3U) {
         // vol2col
         vol2col(context.device_context(), in_slice, dilations, strides,
                 paddings, &col);
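The heart of this refactor is that filter_shape_vec and output_shape_vec now keep their leading {k_o, k_i} and {o_n, o_c} entries instead of erasing them, and the im2col buffer shape is assembled by indexing from offset 2. A self-contained sketch of that construction with made-up 2-D shapes (the concrete numbers are illustrative only; in the kernel these come from the filter, output, and input tensors):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical conv2d case: filter {k_o, k_i, k_h, k_w} and
  // output {o_n, o_c, o_h, o_w}; input_channels_per_group stands in
  // for input->dims()[1] / groups.
  std::vector<int64_t> filter_shape_vec = {64, 3, 3, 3};
  std::vector<int64_t> output_shape_vec = {1, 64, 32, 32};
  int64_t input_channels_per_group = 3;  // i_c / g

  // data_dim is the number of spatial dimensions: 2 for conv2d, 3 for conv3d.
  size_t data_dim = filter_shape_vec.size() - 2;

  // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} (or the 3-D analogue),
  // built by skipping the two leading non-spatial entries of each vector.
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = input_channels_per_group;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
  }

  for (int64_t d : col_shape_vec) std::cout << d << ' ';  // prints: 3 3 3 32 32
  std::cout << '\n';
}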
@@ -206,34 +202,30 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w)
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
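For the GEMM view used by both kernels, flatten_to_2d collapses the first data_dim + 1 entries of col_shape into the row extent and the remainder into the column extent, which is what the (i_c/g * k_h * k_w, o_h * o_w) comment describes. A hypothetical free-function sketch of that collapse over plain vectors (the real routine operates on framework::DDim; this only mirrors the stated semantics):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Collapse a shape into (product of the first num_col_dims extents,
// product of the remaining extents), mirroring how the kernels turn
// col_shape into a 2-D matrix shape for the GEMM.
std::pair<int64_t, int64_t> FlattenTo2DSketch(
    const std::vector<int64_t>& shape, size_t num_col_dims) {
  auto prod = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1},
                           std::multiplies<int64_t>());
  };
  return {prod(shape.begin(), shape.begin() + num_col_dims),
          prod(shape.begin() + num_col_dims, shape.end())};
}

// With col_shape {3, 3, 3, 32, 32} and data_dim = 2, num_col_dims is 3:
// FlattenTo2DSketch({3, 3, 3, 32, 32}, 3) == {27, 1024},
// i.e. (i_c/g * k_h * k_w, o_h * o_w).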
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-          if (is_expand && filter_shape_vec.size() == 2) {
+          if (is_expand && data_dim == 2U) {
             col2im(context.device_context(), col, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &in_grad_slice);
-          } else if (is_expand && filter_shape_vec.size() == 3) {
+          } else if (is_expand && data_dim == 3U) {
             col2vol(context.device_context(), col, dilations, strides, paddings,
                     &in_grad_slice);
           }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
         }
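A note on why the dispatch conditions had to change along with the shape vectors: before this patch, filter_shape_vec was erased down to its spatial extents, so size() == 2 or 3 identified conv2d versus conv3d; with the full {k_o, k_i, ...} shape retained, size() is 4 or 5, and the comparisons move to data_dim. A tiny sketch of the shared dispatch (the stub names are hypothetical; the real calls are the im2col/vol2col and col2im/col2vol functors):

#include <cstddef>

void Im2ColStub() { /* 2-D expansion */ }
void Vol2ColStub() { /* 3-D expansion */ }

// data_dim, not the full filter rank, now selects the expansion routine.
void ExpandSketch(bool is_expand, size_t data_dim) {
  if (!is_expand) return;  // 1x1 fast path: reuse the input buffer directly
  if (data_dim == 2U) {
    Im2ColStub();
  } else if (data_dim == 3U) {
    Vol2ColStub();
  }
}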