@@ -99,20 +99,20 @@ class GemmConvKernel : public framework::OpKernel<T> {
99
99
// use col_shape in the im2col calculation
100
100
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
101
101
// o_h, o_w}
102
- std::vector< int64_t > col_shape_vec ( filter_shape_vec.size () +
103
- output_shape_vec. size () - 3 );
104
- col_shape_vec. assign ( 1 , input->dims ()[1 ] / groups) ;
105
- col_shape_vec. insert (col_shape_vec. end (), filter_shape_vec. begin () + 2 ,
106
- filter_shape_vec. end ()) ;
107
- col_shape_vec. insert (col_shape_vec. end (), output_shape_vec. begin () + 2 ,
108
- output_shape_vec. end ());
102
+ size_t data_dim = filter_shape_vec.size () - 2 ;
103
+ std::vector< int64_t > col_shape_vec ( 1 + 2 * data_dim );
104
+ col_shape_vec[ 0 ] = input->dims ()[1 ] / groups;
105
+ for ( size_t j = 0 ; j < data_dim; ++j) {
106
+ col_shape_vec[j + 1 ] = filter_shape_vec[j + 2 ] ;
107
+ col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2 ];
108
+ }
109
109
framework::DDim col_shape (framework::make_ddim (col_shape_vec));
110
110
111
111
// use col_matrix_shape in the gemm calculation
112
112
// size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
113
113
// o_h * o_w)
114
114
framework::DDim col_matrix_shape =
115
- framework::flatten_to_2d (col_shape, filter_shape_vec. size () - 2 + 1 );
115
+ framework::flatten_to_2d (col_shape, data_dim + 1 );
116
116
117
117
bool is_expand = IsExpand (filter_shape_vec, strides, paddings, dilations);
118
118
Tensor col;
@@ -155,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
155
155
col.ShareDataWith (in_slice);
156
156
col_matrix.ShareDataWith (col);
157
157
col_matrix.Resize (col_matrix_shape);
158
- } else if (filter_shape_vec. size () == 4 ) {
158
+ } else if (data_dim == 2U ) {
159
159
// im2col
160
160
im2col (context.device_context (), in_slice, dilations, strides,
161
161
std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
162
162
paddings[1 ]},
163
163
&col);
164
- } else if (filter_shape_vec. size () == 5 ) {
164
+ } else if (data_dim == 3U ) {
165
165
// vol2col
166
166
vol2col (context.device_context (), in_slice, dilations, strides,
167
167
paddings, &col);
@@ -211,21 +211,21 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
211
211
// use col_shape in the im2col calculation
212
212
// col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
213
213
// o_h, o_w}
214
- std::vector< int64_t > col_shape_vec ( filter_shape_vec.size () +
215
- output_shape_vec. size () - 3 );
216
- col_shape_vec. assign ( 1 , input->dims ()[1 ] / groups) ;
217
- col_shape_vec. insert (col_shape_vec. end (), filter_shape_vec. begin () + 2 ,
218
- filter_shape_vec. end ()) ;
219
- col_shape_vec. insert (col_shape_vec. end (), output_shape_vec. begin () + 2 ,
220
- output_shape_vec. end ());
214
+ size_t data_dim = filter_shape_vec.size () - 2 ;
215
+ std::vector< int64_t > col_shape_vec ( 1 + 2 * data_dim );
216
+ col_shape_vec[ 0 ] = input->dims ()[1 ] / groups;
217
+ for ( size_t j = 0 ; j < data_dim; ++j) {
218
+ col_shape_vec[j + 1 ] = filter_shape_vec[j + 2 ] ;
219
+ col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2 ];
220
+ }
221
221
framework::DDim col_shape (framework::make_ddim (col_shape_vec));
222
222
223
223
// use col_matrix_shape in the gemm calculation
224
224
// size: (i_c/g * k_h * k_w, o_h * o_w)
225
225
// or
226
226
// (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
227
227
framework::DDim col_matrix_shape =
228
- framework::flatten_to_2d (col_shape, filter_shape_vec. size () - 2 + 1 );
228
+ framework::flatten_to_2d (col_shape, data_dim + 1 );
229
229
230
230
framework::DDim input_shape = framework::slice_ddim (
231
231
input->dims (), 1 , static_cast <int >(input->dims ().size ()));
@@ -286,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
286
286
out_grad_slice, false , T (1.0 ), &col_matrix,
287
287
T (0.0 ));
288
288
289
- if (is_expand && filter_shape_vec. size () == 4 ) {
289
+ if (is_expand && data_dim == 2U ) {
290
290
col2im (context.device_context (), col, dilations, strides,
291
291
std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
292
292
paddings[1 ]},
293
293
&in_grad_slice);
294
- } else if (is_expand && filter_shape_vec. size () == 5 ) {
294
+ } else if (is_expand && data_dim == 3U ) {
295
295
col2vol (context.device_context (), col, dilations, strides, paddings,
296
296
&in_grad_slice);
297
297
}
@@ -320,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
320
320
col.ShareDataWith (in_slice);
321
321
col_matrix.ShareDataWith (col);
322
322
col_matrix.Resize (col_matrix_shape);
323
- } else if (filter_shape_vec. size () == 4 ) {
323
+ } else if (data_dim == 2U ) {
324
324
im2col (context.device_context (), in_slice, dilations, strides,
325
325
std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
326
326
paddings[1 ]},
327
327
&col);
328
- } else if (filter_shape_vec. size () == 5 ) {
328
+ } else if (data_dim == 3U ) {
329
329
vol2col (context.device_context (), in_slice, dilations, strides,
330
330
paddings, &col);
331
331
}
0 commit comments