@@ -70,7 +70,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
70
70
std::vector<int > strides = context.Attr <std::vector<int >>(" strides" );
71
71
std::vector<int > paddings = context.Attr <std::vector<int >>(" paddings" );
72
72
std::vector<int > dilations = context.Attr <std::vector<int >>(" dilations" );
73
- // groups will always be disabled in conv2dtranspose.
73
+ int groups = context. Attr < int >( " groups " );
74
74
75
75
const int batch_size = static_cast <int >(input->dims ()[0 ]);
76
76
@@ -81,18 +81,18 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
81
81
82
82
// use col_shape in the im2col and col2im (or vol2col and col2vol)
83
83
// calculation
84
- // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
84
+ // col_shape_vec: {c/g , k_h, k_w, h, w} or {c/g , k_d, k_h, k_w, d, h, w}
85
85
size_t data_dim = filter_shape_vec.size () - 2 ;
86
86
std::vector<int64_t > col_shape_vec (1 + 2 * data_dim);
87
- col_shape_vec[0 ] = output->dims ()[1 ];
87
+ col_shape_vec[0 ] = output->dims ()[1 ] / groups ;
88
88
for (size_t j = 0 ; j < data_dim; ++j) {
89
89
col_shape_vec[j + 1 ] = filter_shape_vec[j + 2 ];
90
90
col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2 ];
91
91
}
92
92
DDim col_shape (framework::make_ddim (col_shape_vec));
93
93
94
94
// use col_matrix_shape in the gemm calculation
95
- // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
95
+ // size: (c/g * k_h * k_w, h * w) or (c/g * k_d * k_h * k_w, d * h * w)
96
96
DDim col_matrix_shape = framework::flatten_to_2d (col_shape, data_dim + 1 );
97
97
98
98
Tensor col;
@@ -111,7 +111,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
111
111
// input matrix size: (m, h * w) or (m, d * h * w)
112
112
DDim input_matrix_shape = {input->dims ()[1 ], col_matrix_shape[1 ]};
113
113
114
- // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
114
+ // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
115
115
DDim filter_matrix_shape = {input->dims ()[1 ], col_matrix_shape[0 ]};
116
116
filter.Resize (filter_matrix_shape);
117
117
@@ -121,6 +121,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
121
121
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
122
122
set_zero (dev_ctx, output, static_cast <T>(0 ));
123
123
124
+ int in_step = static_cast <int >(input->dims ()[1 ]) / groups;
125
+ int out_step = static_cast <int >(output->dims ()[1 ]) / groups;
124
126
math::Col2ImFunctor<math::ColFormat::kCFO , DeviceContext, T> col2im;
125
127
math::Col2VolFunctor<DeviceContext, T> col2vol;
126
128
@@ -133,22 +135,29 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
133
135
// output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
134
136
Tensor output_batch = output->Slice (i, i + 1 ).Resize (output_shape);
135
137
136
- // col_matrix = filter * input_batch
137
- // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
138
- blas.MatMul (filter, true , input_batch, false , static_cast <T>(1.0 ),
139
- &col_matrix, static_cast <T>(0.0 ));
140
-
141
- if (data_dim == 2U ) {
142
- // col2im: col_matrix -> dy
143
- // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
144
- col2im (dev_ctx, col, dilations, strides,
145
- std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
146
- paddings[1 ]},
147
- &output_batch);
148
- } else if (data_dim == 3U ) {
149
- // col2vol: col_matrix -> dy
150
- // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
151
- col2vol (dev_ctx, col, dilations, strides, paddings, &output_batch);
138
+ for (int g = 0 ; g < groups; g++) {
139
+ Tensor in_slice = input_batch.Slice (g * in_step, (g + 1 ) * in_step);
140
+ Tensor filter_slice = filter.Slice (g * in_step, (g + 1 ) * in_step);
141
+ Tensor out_slice = output_batch.Slice (g * out_step, (g + 1 ) * out_step);
142
+
143
+ // col_matrix = filter_slice * input_slice
144
+ // of shape (c/g * k_h * k_w, h * w)
145
+ // or (c/g * k_d * k_h * k_w, d * h * w)
146
+ blas.MatMul (filter_slice, true , in_slice, false , static_cast <T>(1.0 ),
147
+ &col_matrix, static_cast <T>(0.0 ));
148
+
149
+ if (data_dim == 2U ) {
150
+ // col2im: col_matrix -> dy
151
+ // from (c/g * k_h * k_w, h * w) to (c/g, o_h, o_w)
152
+ col2im (dev_ctx, col, dilations, strides,
153
+ std::vector<int >{paddings[0 ], paddings[1 ], paddings[0 ],
154
+ paddings[1 ]},
155
+ &out_slice);
156
+ } else if (data_dim == 3U ) {
157
+ // col2vol: col_matrix -> dy
158
+ // from (c/g * k_d * k_h * k_w, d * h * w) to (c/g, o_d, o_h, o_w)
159
+ col2vol (dev_ctx, col, dilations, strides, paddings, &out_slice);
160
+ }
152
161
}
153
162
}
154
163
}
@@ -174,6 +183,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
174
183
std::vector<int > strides = context.Attr <std::vector<int >>(" strides" );
175
184
std::vector<int > paddings = context.Attr <std::vector<int >>(" paddings" );
176
185
std::vector<int > dilations = context.Attr <std::vector<int >>(" dilations" );
186
+ int groups = context.Attr <int >(" groups" );
177
187
178
188
const int batch_size = static_cast <int >(input->dims ()[0 ]);
179
189
@@ -205,9 +215,11 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
205
215
// input matrix size: (m, h * w) or (m, d * h * w)
206
216
DDim input_matrix_shape = {input->dims ()[1 ], col_matrix_shape[1 ]};
207
217
208
- // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
209
- DDim filter_matrix_shape = {input->dims ()[1 ], col_matrix_shape[0 ]};
218
+ // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
219
+ DDim filter_matrix_shape = {input->dims ()[1 ], col_matrix_shape[0 ] / groups };
210
220
filter.Resize (filter_matrix_shape);
221
+ int in_step = static_cast <int >(input->dims ()[1 ]) / groups;
222
+ int col_step = static_cast <int >(col_matrix_shape[0 ]) / groups;
211
223
212
224
// convolution transpose grad on input:
213
225
// im2col + gemm (similar to conv-forward)
@@ -233,7 +245,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
233
245
if (input_grad) {
234
246
input_grad->mutable_data <T>(context.GetPlace ());
235
247
}
236
- if (filter_grad) { // filter size (m, c, k_h, k_w)
248
+ if (filter_grad) { // filter size (m, c/g , k_h, k_w)
237
249
filter_grad->mutable_data <T>(context.GetPlace ());
238
250
set_zero (dev_ctx, filter_grad, static_cast <T>(0 ));
239
251
filter_grad_ = *filter_grad;
@@ -268,8 +280,17 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
268
280
// or
269
281
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
270
282
// d * h * w); per group when groups > 1: m -> m/g, c -> c/g
271
- blas.MatMul (filter, false , col_matrix, false , static_cast <T>(1.0 ),
272
- &input_grad_batch, static_cast <T>(0.0 ));
283
+ for (int g = 0 ; g < groups; g++) {
284
+ Tensor input_grad_slice =
285
+ input_grad_batch.Slice (g * in_step, (g + 1 ) * in_step);
286
+ Tensor filter_slice = filter.Slice (g * in_step, (g + 1 ) * in_step);
287
+ Tensor col_matrix_slice =
288
+ col_matrix.Slice (g * col_step, (g + 1 ) * col_step);
289
+
290
+ blas.MatMul (filter_slice, false , col_matrix_slice, false ,
291
+ static_cast <T>(1.0 ), &input_grad_slice,
292
+ static_cast <T>(0.0 ));
293
+ }
273
294
}
274
295
if (filter_grad) {
275
296
// input batch
@@ -279,8 +300,17 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
279
300
// or
280
301
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
281
302
// k_h * k_w); per group when groups > 1: m -> m/g, c -> c/g
282
- blas.MatMul (in_batch, false , col_matrix, true , static_cast <T>(1.0 ),
283
- &filter_grad_, static_cast <T>(1.0 ));
303
+ for (int g = 0 ; g < groups; g++) {
304
+ Tensor in_batch_slice =
305
+ in_batch.Slice (g * in_step, (g + 1 ) * in_step);
306
+ Tensor filter_grad_slice =
307
+ filter_grad_.Slice (g * in_step, (g + 1 ) * in_step);
308
+ Tensor col_matrix_slice =
309
+ col_matrix.Slice (g * col_step, (g + 1 ) * col_step);
310
+ blas.MatMul (in_batch_slice, false , col_matrix_slice, true ,
311
+ static_cast <T>(1.0 ), &filter_grad_slice,
312
+ static_cast <T>(1.0 ));
313
+ }
284
314
}
285
315
}
286
316
}
0 commit comments