@@ -180,11 +180,327 @@ Tensor _convolution_decomp(
// static auto op = c10::Dispatcher::singleton()
//   .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
// return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
- // }
+
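+ // Bias gradient: grad_output summed over the batch dim (0) and every spatial
+ // dim (2, 3, ...), leaving only the channel dim. The plumbing below calls this
+ // on the still-wrapped grad_output_, so the existing batching rules for sum
+ // take care of any vmap dims and no dedicated batch rule is needed for bias.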
+ static Tensor compute_grad_bias(
+     const Tensor& grad_output_, std::array<bool, 3> output_mask) {
+   if (!output_mask[2]) {
+     return Tensor();
+   }
+   DimVector reduce_dims;
+   reduce_dims.resize(grad_output_.dim() - 1);
+   reduce_dims[0] = 0;
+   std::iota(reduce_dims.begin() + 1, reduce_dims.end(), 2);
+   return grad_output_.sum(reduce_dims);
+ }
+
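+ // Every call below only asks convolution_backward for the gradient w.r.t. the
+ // dummy argument itself; that gradient depends on the dummy's shape but never
+ // on its values, which is why an _efficientzerotensor (a cheap, never-read
+ // zero tensor) is enough.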
+ // makes a dummy tensor whose size along `dim` is scaled by batch_size
+ Tensor make_dummy(
+     const Tensor& tensor, optional<int64_t> tensor_bdim,
+     int64_t dim, int64_t batch_size) {
+   auto tensor_ = tensor_bdim ? tensor.select(*tensor_bdim, 0) : tensor;
+   auto orig_size = tensor_.size(dim);
+   tensor_ = tensor_.slice(dim, 0, 1);
+
+   DimVector expand_shape(tensor_.sizes().begin(), tensor_.sizes().end());
+   expand_shape[dim] = batch_size * orig_size;
+
+   // return tensor_.new_zeros(expand_shape);
+   return at::_efficientzerotensor(expand_shape, tensor.options());
+ }
+
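+ // Batch rule for the grad_input output of convolution_backward. Depending on
+ // which arguments carry the vmapped dim, it is folded into either the batch
+ // (N) dim or a channel dim of grad_output/weight; when both grad_output and
+ // weight are batched, `groups` is multiplied by the batch size so a single
+ // grouped call computes every per-example gradient. The batch dim is then
+ // reshaped back out of the result.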
+ std::tuple<Tensor,optional<int64_t>>
+ convolution_backward_input_batch_rule(
+     const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+     const Tensor& input, optional<int64_t> input_bdim,
+     const Tensor& weight, optional<int64_t> weight_bdim,
+     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
+     IntArrayRef output_padding, int64_t groups) {
+   const std::array<bool, 3> mask = {true, false, false};
+   if (grad_output_bdim && weight_bdim) {
+     // regular: BNO, BOI -> N(BO), (BO)I -> N(BI)
+     // transposed: BNO, BIO -> N(BO), (BI)O -> N(BI)
+     const auto batch_size = weight.size(*weight_bdim);
+     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+     const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
+     auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+     const auto result = at::convolution_backward(
+         grad_output_, dummy_input, weight_, nullopt, stride, padding,
+         dilation, transposed, output_padding, groups * batch_size, mask);
+     const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+     return std::make_tuple(grad_input, 1);
+   } else if (grad_output_bdim && !weight_bdim) {
+     // BNO, OI -> (BN)O, OI -> (BN)I
+     // transposed is the same.
+     const auto batch_size = grad_output.size(*grad_output_bdim);
+     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
+     auto dummy_input = make_dummy(input, input_bdim, 0, batch_size);
+     const auto result = at::convolution_backward(
+         grad_output_, dummy_input, weight, nullopt, stride, padding,
+         dilation, transposed, output_padding, groups, mask);
+     const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
+     return std::make_tuple(grad_input, 0);
+   } else if (!grad_output_bdim && weight_bdim) {
+     const auto batch_size = weight.size(*weight_bdim);
+     if (groups == 1) {
+       // regular: NO, BOI -> NO, O(BI) -> N(BI)
+       // transposed: NO, BIO -> NO, (BI)O -> N(BI)
+       const auto in_ch_dim = transposed ? 0 : 1;
+       const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight);
+       auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+       const auto result = at::convolution_backward(
+           grad_output, dummy_input, weight_, nullopt, stride, padding,
+           dilation, transposed, output_padding, groups, mask);
+       const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+       return std::make_tuple(grad_input, 1);
+     }
+     Tensor grad_input;
+     if (!transposed) {
+       // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
+       const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
+       auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+       const auto result = at::convolution_backward(
+           grad_output, dummy_input, weight_, nullopt, stride, padding,
+           dilation, transposed, output_padding, groups, mask);
+       grad_input = std::get<0>(result); // N(GBI)
+     } else {
+       // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
+       auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
+       weight_ = reshape_dim_outof(1, groups, weight_);         // BGIO
+       weight_ = weight_.transpose(0, 1);                       // GBIO
+       weight_ = weight_.flatten(0, 2);                         // (GBI)O
+       const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
+       const auto result = at::convolution_backward(
+           grad_output, dummy_input, weight_, nullopt, stride, padding,
+           dilation, transposed, output_padding, groups, mask);
+       grad_input = std::get<0>(result); // N(GBI)
+     }
+     // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
+     grad_input = reshape_dim_outof(1, groups, grad_input);
+     grad_input = reshape_dim_outof(2, batch_size, grad_input);
+     grad_input = grad_input.transpose(1, 2);
+     grad_input = reshape_dim_into(2, 2, grad_input);
+     return std::make_tuple(grad_input, 1);
+   } else {
+     TORCH_INTERNAL_ASSERT(input_bdim);
+     const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
+     const auto result = at::convolution_backward(
+         grad_output, dummy_input, weight, nullopt, stride, padding,
+         dilation, transposed, output_padding, groups, mask);
+     return std::make_tuple(std::get<0>(result), nullopt);
+   }
+ }
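+ // Batch rule for the grad_weight output; mirrors the grad_input rule above,
+ // except the vmapped dim is folded into grad_output/input and reshaped back
+ // out of the returned grad_weight.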
+ std::tuple<Tensor,optional<int64_t>>
+ convolution_backward_weight_batch_rule(
+     const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+     const Tensor& input, optional<int64_t> input_bdim,
+     const Tensor& weight, optional<int64_t> weight_bdim,
+     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
+     IntArrayRef output_padding, int64_t groups) {
+   const std::array<bool, 3> mask = {false, true, false};
+   if (grad_output_bdim && input_bdim) {
+     // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed)
+     const auto batch_size = input.size(*input_bdim);
+     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+     const auto input_ = reshape_dim_into(*input_bdim, 1, input);
+     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
+     const auto result = at::convolution_backward(
+         grad_output_, input_, dummy_weight, nullopt, stride, padding,
+         dilation, transposed, output_padding, groups * batch_size, mask);
+     auto grad_weight = std::get<1>(result);
+     grad_weight = reshape_dim_outof(0, batch_size, grad_weight);
+     return std::make_tuple(grad_weight, 0);
+   } else if (grad_output_bdim && !input_bdim) {
+     const auto batch_size = grad_output.size(*grad_output_bdim);
+     if (groups == 1) {
+       // regular: BNO, NI -> N(BO), NI -> (BO)I
+       // transposed: BNO, NI -> N(BO), NI -> I(BO)
+       const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+       const auto out_ch_dim = transposed ? 1 : 0;
+       const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size);
+       const auto result = at::convolution_backward(
+           grad_output_, input, dummy_weight, nullopt, stride, padding,
+           dilation, transposed, output_padding, groups, mask);
+       auto grad_weight = std::get<1>(result);
+       grad_weight = reshape_dim_outof(out_ch_dim, batch_size, grad_weight);
+       return std::make_tuple(grad_weight, out_ch_dim);
+     } else {
+       auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim); // BN(GO)
+       grad_output_ = reshape_dim_outof(2, groups, grad_output_);              // BNGO
+       grad_output_ = grad_output_.movedim(0, 2);                              // NGBO
+       grad_output_ = grad_output_.flatten(1, 3);                              // N(GBO)
+       if (!transposed) {
+         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I
+         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
+         const auto result = at::convolution_backward(
+             grad_output_, input, dummy_weight, nullopt, stride, padding,
+             dilation, transposed, output_padding, groups, mask);
+         auto grad_weight = std::get<1>(result);
+         grad_weight = grad_weight.unflatten(0, { groups, batch_size, -1 }); // GBOI
+         grad_weight = grad_weight.transpose(0, 1);                          // BGOI
+         grad_weight = grad_weight.flatten(1, 2);                            // B(GO)I
+         return std::make_tuple(grad_weight, 0);
+       } else {
+         // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO)
+         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
+         const auto result = at::convolution_backward(
+             grad_output_, input, dummy_weight, nullopt, stride, padding,
+             dilation, transposed, output_padding, groups, mask);
+         auto grad_weight = std::get<1>(result);
+         grad_weight = reshape_dim_outof(1, batch_size, grad_weight);
+         return std::make_tuple(grad_weight, 1);
+       }
+     }
+   } else if (!grad_output_bdim && input_bdim) {
+     const auto batch_size = input.size(*input_bdim);
+     if (groups == 1) {
+       // regular: NO, BNI -> NO, N(BI) -> O(BI)
+       // transposed: NO, BNI -> NO, N(BI) -> (BI)O
+       const auto input_ = reshape_dim_into(*input_bdim, 1, input);
+       const auto in_ch_dim = transposed ? 0 : 1;
+       const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size);
+       const auto result = at::convolution_backward(
+           grad_output, input_, dummy_weight, nullopt, stride, padding,
+           dilation, transposed, output_padding, groups, mask);
+       auto grad_weight = std::get<1>(result);
+       grad_weight = reshape_dim_outof(in_ch_dim, batch_size, grad_weight);
+       return std::make_tuple(grad_weight, in_ch_dim);
+     } else {
+       auto input_ = moveBatchDimToFront(input, input_bdim); // BN(GI)
+       input_ = reshape_dim_outof(2, groups, input_);        // BNGI
+       input_ = input_.movedim(0, 2);                        // NGBI
+       input_ = input_.flatten(1, 3);                        // N(GBI)
+       if (!transposed) {
+         // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
+         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
+         const auto result = at::convolution_backward(
+             grad_output, input_, dummy_weight, nullopt, stride, padding,
+             dilation, transposed, output_padding, groups, mask);
+         auto grad_weight = std::get<1>(result);
+         grad_weight = reshape_dim_outof(1, batch_size, grad_weight);
+         return std::make_tuple(grad_weight, 1);
+       } else {
+         // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O
+         const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
+         const auto result = at::convolution_backward(
+             grad_output, input_, dummy_weight, nullopt, stride, padding,
+             dilation, transposed, output_padding, groups, mask);
+         auto grad_weight = std::get<1>(result);
+         grad_weight = grad_weight.unflatten(0, { groups, batch_size, -1 }); // GBIO
+         grad_weight = grad_weight.transpose(0, 1);                          // BGIO
+         grad_weight = grad_weight.flatten(1, 2);                            // B(GI)O
+         return std::make_tuple(grad_weight, 0);
+       }
+     }
+   } else {
+     TORCH_INTERNAL_ASSERT(weight_bdim);
+     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1);
+     const auto result = at::convolution_backward(
+         grad_output, input, dummy_weight, nullopt, stride, padding,
+         dilation, transposed, output_padding, groups, mask);
+     return std::make_tuple(std::get<1>(result), nullopt);
+   }
+ }
+
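+ // Entry point registered for aten::convolution_backward under vmap: unwraps the
+ // tensors at the current vmap level, computes grad_bias up front, handles the
+ // everything-batched (model ensembling) case with a single grouped call, and
+ // otherwise defers to the per-output batch rules above before re-wrapping the
+ // results with makeBatched.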
+ std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
+     const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
+     const c10::optional<IntArrayRef> bias_sizes_opt,
+     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
+     IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
+   const auto maybe_layer = maybeCurrentDynamicLayer();
+   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+   int64_t cur_level = maybe_layer->layerId();
+   Tensor grad_output;
+   optional<int64_t> grad_output_bdim;
+   std::tie(grad_output, grad_output_bdim) = unwrapTensorAtLevel(grad_output_, cur_level);
+   Tensor input;
+   optional<int64_t> input_bdim;
+   std::tie(input, input_bdim) = unwrapTensorAtLevel(input_, cur_level);
+   Tensor weight;
+   optional<int64_t> weight_bdim;
+   std::tie(weight, weight_bdim) = unwrapTensorAtLevel(weight_, cur_level);
+
+   const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
+   output_mask[2] = false;
+
+   // TODO: A little bird says that unfold + matmul is actually faster than
+   // group convolution in many cases. We should benchmark some of
+   // the common cases and replace things with unfold + matmul as necessary.
+
+   // Notation:
+   // B - a batch dimension
+   // G - groups (sometimes omitted because it doesn't matter)
+   // NO - grad_output
+   // NI - input
+   // OI - weight
+   // "(BO)I" - we don't actually care about the values of this Tensor,
+   //           we just need to create a tensor on the same device with the
+   //           correct shape and pray that the implementation is smart enough
+   //           to not do anything with it.
+
+   // BNO, BNI, BOI
+   // AKA the model ensembling case
+   if (grad_output_bdim && input_bdim && weight_bdim) {
+     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+     grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);
+
+     // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
+     const auto batch_size = weight.size(*weight_bdim);
+     input = reshape_dim_into(*input_bdim, 1, input);
+     weight = reshape_dim_into(*weight_bdim, 0, weight);
+     const auto result = at::convolution_backward(
+         grad_output, input, weight, nullopt, stride, padding, dilation,
+         transposed, output_padding, batch_size * groups, output_mask);
+     // N(BI), (BO)I -> NBI, BOI
+     const auto grad_input = output_mask[0] ?
+       reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
+     const auto grad_weight = output_mask[1] ?
+       reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
+     return std::make_tuple(
+         output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
+         output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
+         grad_bias);
+   }
+
+   Tensor grad_input;
+   if (output_mask[0]) {
+     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+     const auto result = convolution_backward_input_batch_rule(
+         grad_output, grad_output_bdim,
+         input, input_bdim,
+         weight, weight_bdim,
+         stride, padding, dilation, transposed, output_padding, groups);
+     grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
+   }
+
+   Tensor grad_weight;
+   if (output_mask[1]) {
+     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+     const auto result = convolution_backward_weight_batch_rule(
+         grad_output, grad_output_bdim,
+         input, input_bdim,
+         weight, weight_bdim,
+         stride, padding, dilation, transposed, output_padding, groups);
+     grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
+   }
+   return std::make_tuple(grad_input, grad_weight, grad_bias);
+
+   // Someone's definitely going to find a problem with this batching rule so
+   // I'm leaving the following fallback here in case we need it back.
+   // static auto op = c10::Dispatcher::singleton()
+   //   .findSchemaOrThrow("aten::convolution_backward", "");
+   // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
+   //   grad_output_, input_, weight_, bias_sizes_opt,
+   //   stride, padding, dilation, transposed, output_padding, groups, output_mask
+   // });
+   // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
+ }
+

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("convolution", convolution_batch_rule);
  m.impl("_convolution", _convolution_decomp);
+  m.impl("convolution_backward", convolution_backward_plumbing);
}

}} // namespace at::functorch
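
As a side note, here is a minimal standalone sketch (not part of the patch; it uses only the public ATen C++ API, and the shapes are illustrative) of the identity the batch rules above rely on: a batch of B independent convolutions is equivalent to one grouped convolution with groups multiplied by B, once the extra batch dimension is folded into the channel dimension.

#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  const int64_t B = 4, N = 2, C_in = 3, C_out = 5, H = 8, W = 8, k = 3;
  // Per-example inputs and per-example weights (the "model ensembling" layout).
  auto input  = at::randn({B, N, C_in, H, W});
  auto weight = at::randn({B, C_out, C_in, k, k});

  // Reference: B separate convolutions in a loop.
  std::vector<at::Tensor> per_example;
  for (int64_t b = 0; b < B; ++b) {
    per_example.push_back(at::conv2d(input[b], weight[b]));
  }
  auto expected = at::stack(per_example);  // B, N, C_out, H', W'

  // Grouped formulation: fold B into the channel dims and set groups = B.
  auto input_  = input.transpose(0, 1).reshape({N, B * C_in, H, W});  // N, (B C_in), H, W
  auto weight_ = weight.reshape({B * C_out, C_in, k, k});             // (B C_out), C_in, k, k
  auto out = at::conv2d(input_, weight_, /*bias=*/{}, /*stride=*/{1, 1},
                        /*padding=*/{0, 0}, /*dilation=*/{1, 1}, /*groups=*/B);
  // N, (B C_out), H', W' -> B, N, C_out, H', W'
  out = out.reshape({N, B, C_out, out.size(2), out.size(3)}).transpose(0, 1);

  std::cout << at::allclose(out, expected) << "\n";  // expect: 1
  return 0;
}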