@@ -349,6 +349,39 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
349
349
return std::make_tuple (result0, mean, rstd);
350
350
}
351
351
352
+ std::tuple<at::Tensor,optional<int64_t >> group_norm_backward_no_weight_bias_batch_rule (
353
+ const at::Tensor & grad_out, optional<int64_t > grad_out_bdim,
354
+ const at::Tensor & input, optional<int64_t > input_bdim,
355
+ const at::Tensor & mean, optional<int64_t > mean_bdim,
356
+ const at::Tensor & rstd, optional<int64_t > rstd_bdim,
357
+ int64_t N, int64_t C, int64_t HxW, int64_t group) {
358
+ auto grad_out_ = moveBatchDimToFront (grad_out, grad_out_bdim);
359
+ auto input_ = moveBatchDimToFront (input, input_bdim);
360
+ auto mean_ = moveBatchDimToFront (mean, mean_bdim);
361
+ auto rstd_ = moveBatchDimToFront (rstd, rstd_bdim);
362
+
363
+ const auto bdim_size = get_bdim_size2 (grad_out, grad_out_bdim, input, input_bdim);
364
+ grad_out_ = ensure_has_bdim (grad_out, grad_out_bdim.has_value (), bdim_size);
365
+ input_ = ensure_has_bdim (input_, input_bdim.has_value (), bdim_size);
366
+ mean_ = ensure_has_bdim (mean_, mean_bdim.has_value (), bdim_size);
367
+ rstd_ = ensure_has_bdim (rstd_, rstd_bdim.has_value (), bdim_size);
368
+
369
+ grad_out_ = reshape_dim_into (0 , 0 , grad_out_); // [B0 * N, C, *]
370
+ input_ = reshape_dim_into (0 , 0 , input_); // [B0 * N, C, *]
371
+ mean_ = reshape_dim_into (0 , 0 , mean_); // [B0 * N, G]
372
+ rstd_ = reshape_dim_into (0 , 0 , rstd_); // [B0 * N, G]
373
+
374
+ const auto result = native_group_norm_backward (
375
+ grad_out_.contiguous (),
376
+ input_.contiguous (),
377
+ mean_.contiguous (),
378
+ rstd_.contiguous (),
379
+ nullopt, N * bdim_size, C, HxW, group, {true , false , false });
380
+ auto result0 = std::get<0 >(result);
381
+ result0 = reshape_dim_outof (0 , bdim_size, result0);
382
+ return std::make_tuple (result0, 0 );
383
+ }
384
+
352
385
std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing (
353
386
const Tensor & grad_out, const Tensor & input, const Tensor & mean,
354
387
const Tensor & rstd, const c10::optional<Tensor> & weight_opt,
@@ -368,9 +401,6 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
368
401
return at::native_group_norm_backward (grad_out, input, mean, rstd, weight_opt, N, C, HxW, group, output_mask);
369
402
}
370
403
371
- Tensor grad_out_value;
372
- optional<int64_t > grad_out_bdim;
373
- std::tie (grad_out_value, grad_out_bdim) = unwrapTensorAtLevel (grad_out, cur_level);
374
404
Tensor input_value;
375
405
optional<int64_t > input_bdim;
376
406
std::tie (input_value, input_bdim) = unwrapTensorAtLevel (input, cur_level);
@@ -410,32 +440,16 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
410
440
optional<int64_t > grad_normalized_input_bdim;
411
441
std::tie (grad_normalized_input_value, grad_normalized_input_bdim) =
412
442
unwrapTensorAtLevel (grad_normalized_input, cur_level);
413
- auto grad_out_ = moveBatchDimToFront (grad_normalized_input_value, grad_normalized_input_bdim);
414
- auto input_ = moveBatchDimToFront (input_value, input_bdim);
415
- auto mean_ = moveBatchDimToFront (mean_value, mean_bdim);
416
- auto rstd_ = moveBatchDimToFront (rstd_value, rstd_bdim);
417
-
418
- const auto bdim_size = get_bdim_size3 (grad_out_, grad_out_bdim, input_, input_bdim, weight, weight_bdim);
419
- grad_out_ = ensure_has_bdim (grad_out_, grad_out_bdim.has_value (), bdim_size);
420
- input_ = ensure_has_bdim (input_, input_bdim.has_value (), bdim_size);
421
- mean_ = ensure_has_bdim (mean_, mean_bdim.has_value (), bdim_size);
422
- rstd_ = ensure_has_bdim (rstd_, rstd_bdim.has_value (), bdim_size);
423
-
424
- grad_out_ = reshape_dim_into (0 , 0 , grad_out_); // [B0 * N, C, *]
425
- input_ = reshape_dim_into (0 , 0 , input_); // [B0 * N, C, *]
426
- mean_ = reshape_dim_into (0 , 0 , mean_); // [B0 * N, G]
427
- rstd_ = reshape_dim_into (0 , 0 , rstd_); // [B0 * N, G]
428
443
429
444
c10::impl::ExcludeDispatchKeyGuard guard (kBatchedKey );
430
- const auto result = native_group_norm_backward (
431
- grad_out_,
432
- input_,
433
- mean_,
434
- rstd_,
435
- nullopt, N * bdim_size, C, HxW, group, {true , false , false });
436
- auto result0 = std::get<0 >(result);
437
- result0 = reshape_dim_outof (0 , bdim_size, result0);
438
- grad_input = makeBatched (result0, 0 , cur_level);
445
+ const auto res = group_norm_backward_no_weight_bias_batch_rule (
446
+ grad_normalized_input_value, grad_normalized_input_bdim,
447
+ input_value, input_bdim,
448
+ mean_value, mean_bdim,
449
+ rstd_value, rstd_bdim,
450
+ N, C, HxW, group
451
+ );
452
+ grad_input = makeBatched (std::get<0 >(res), std::get<1 >(res), cur_level);
439
453
}
440
454
return std::make_tuple (grad_input, grad_weight, grad_bias);
441
455
}
0 commit comments