Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit c5cffb8

Browse files
Author: Samantha Andow
fix group norm, add scaffolding for autograd.grad tests (#630)
1 parent c5ce4d0 commit c5cffb8

File tree

2 files changed

+86
-52
lines changed

2 files changed

+86
-52
lines changed

functorch/csrc/BatchRulesNorm.cpp

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,39 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
349349
return std::make_tuple(result0, mean, rstd);
350350
}
351351

352+
std::tuple<at::Tensor,optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
353+
const at::Tensor & grad_out, optional<int64_t> grad_out_bdim,
354+
const at::Tensor & input, optional<int64_t> input_bdim,
355+
const at::Tensor & mean, optional<int64_t> mean_bdim,
356+
const at::Tensor & rstd, optional<int64_t> rstd_bdim,
357+
int64_t N, int64_t C, int64_t HxW, int64_t group) {
358+
auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
359+
auto input_ = moveBatchDimToFront(input, input_bdim);
360+
auto mean_ = moveBatchDimToFront(mean, mean_bdim);
361+
auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);
362+
363+
const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
364+
grad_out_ = ensure_has_bdim(grad_out, grad_out_bdim.has_value(), bdim_size);
365+
input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
366+
mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
367+
rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);
368+
369+
grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
370+
input_ = reshape_dim_into(0, 0, input_); // [B0 * N, C, *]
371+
mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
372+
rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]
373+
374+
const auto result = native_group_norm_backward(
375+
grad_out_.contiguous(),
376+
input_.contiguous(),
377+
mean_.contiguous(),
378+
rstd_.contiguous(),
379+
nullopt, N * bdim_size, C, HxW, group, {true, false, false});
380+
auto result0 = std::get<0>(result);
381+
result0 = reshape_dim_outof(0, bdim_size, result0);
382+
return std::make_tuple(result0, 0);
383+
}
384+
352385
std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
353386
const Tensor & grad_out, const Tensor & input, const Tensor & mean,
354387
const Tensor & rstd, const c10::optional<Tensor> & weight_opt,
@@ -368,9 +401,6 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
368401
return at::native_group_norm_backward(grad_out, input, mean, rstd, weight_opt, N, C, HxW, group, output_mask);
369402
}
370403

371-
Tensor grad_out_value;
372-
optional<int64_t> grad_out_bdim;
373-
std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level);
374404
Tensor input_value;
375405
optional<int64_t> input_bdim;
376406
std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
@@ -410,32 +440,16 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
410440
optional<int64_t> grad_normalized_input_bdim;
411441
std::tie(grad_normalized_input_value, grad_normalized_input_bdim) =
412442
unwrapTensorAtLevel(grad_normalized_input, cur_level);
413-
auto grad_out_ = moveBatchDimToFront(grad_normalized_input_value, grad_normalized_input_bdim);
414-
auto input_ = moveBatchDimToFront(input_value, input_bdim);
415-
auto mean_ = moveBatchDimToFront(mean_value, mean_bdim);
416-
auto rstd_ = moveBatchDimToFront(rstd_value, rstd_bdim);
417-
418-
const auto bdim_size = get_bdim_size3(grad_out_, grad_out_bdim, input_, input_bdim, weight, weight_bdim);
419-
grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
420-
input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
421-
mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
422-
rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);
423-
424-
grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
425-
input_ = reshape_dim_into(0, 0, input_); // [B0 * N, C, *]
426-
mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
427-
rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]
428443

429444
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
430-
const auto result = native_group_norm_backward(
431-
grad_out_,
432-
input_,
433-
mean_,
434-
rstd_,
435-
nullopt, N * bdim_size, C, HxW, group, {true, false, false});
436-
auto result0 = std::get<0>(result);
437-
result0 = reshape_dim_outof(0, bdim_size, result0);
438-
grad_input = makeBatched(result0, 0, cur_level);
445+
const auto res = group_norm_backward_no_weight_bias_batch_rule(
446+
grad_normalized_input_value, grad_normalized_input_bdim,
447+
input_value, input_bdim,
448+
mean_value, mean_bdim,
449+
rstd_value, rstd_bdim,
450+
N, C, HxW, group
451+
);
452+
grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
439453
}
440454
return std::make_tuple(grad_input, grad_weight, grad_bias);
441455
}

test/test_ops.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors
88
import torch
99
from torch import Tensor
10-
import torch.nn.functional as F
1110
import functools
1211
import unittest
13-
import itertools
1412
from contextlib import contextmanager
1513
from torch.testing._internal.common_device_type import instantiate_device_type_tests
1614
from torch.testing._internal.common_device_type import ops
@@ -29,7 +27,6 @@
2927
# tol2,
3028
opsToleranceOverride,
3129
check_vmap_fallback,
32-
loop,
3330
IS_FBCODE,
3431
)
3532
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
@@ -198,6 +195,35 @@ def wrapped(*args):
198195
return wrapped, tuple(flat_args + flat_cotangents)
199196

200197

198+
# Builds a function g(*flat_args, *flat_cotangents) that evaluates `f` on the
# sample and computes vjps with torch.autograd.grad (rather than the vjp
# transform).  Returns (g, flattened_inputs_for_g).
def get_autograd_fn_and_args_with_cotangents(f, sample, cotangents):
    sample_args = tuple([sample.input] + list(sample.args))
    sample_kwargs = sample.kwargs
    flat_args, args_spec = tree_flatten(sample_args)
    flat_cotangents, cotangents_spec = tree_flatten(cotangents)
    num_args = len(flat_args)

    @functools.wraps(f)
    def wrapped(*flat_inputs):
        # The wrapped fn receives the args and the cotangents flattened
        # into one positional tuple; split and unflatten them back.
        assert len(flat_inputs) == num_args + len(flat_cotangents)
        actual_args = tree_unflatten(flat_inputs[:num_args], args_spec)
        grad_outputs = tree_unflatten(flat_inputs[num_args:], cotangents_spec)

        fn, primals = normalize_op_input_output3(
            f, actual_args, sample_kwargs, flat_args,
            sample.output_process_fn_grad)
        out = fn(*primals)
        diff_wrt = tuple(
            p for p in primals if p.requires_grad or p.grad_fn is not None)
        if not diff_wrt:
            # uuugh hack...this will need to be more generic: nothing was
            # differentiable, so return a dummy scalar instead of calling grad
            return (torch.ones(()),)
        return torch.autograd.grad(out, diff_wrt, grad_outputs=grad_outputs)

    return wrapped, tuple(flat_args + flat_cotangents)
225+
226+
201227
# Returns a new function g(*args, *cotangents) that computes vjps and
202228
# sample (*args, *cotangents)
203229
def get_vjpfull_variant(f, sample):
@@ -1433,28 +1459,22 @@ def test_decompositions_torchscriptable(self, device):
14331459
continue
14341460
torch.jit.script(decomposition)
14351461

1436-
def test_group_norm_backward(self, device):
1437-
# group norm will hit the decomposable ``infinitely_differentiable_group_norm_backward`` when
1438-
# GradMode is on, which happens by default in the grad transform. This avoids that
1439-
def f(x, weight, bias, grad_out):
1440-
output = F.group_norm(x, 6, weight, bias)
1441-
inputs = []
1442-
for input in (x, weight, bias):
1443-
if input.requires_grad:
1444-
inputs.append(input)
1445-
return torch.autograd.grad(outputs=output, inputs=inputs, grad_outputs=grad_out)
1446-
1447-
B, N, C, H, W = 2, 3, 24, 5, 7
1448-
for (input_grad, weight_grad, bias_grad) in itertools.product((True, False), (True, False), (True, False)):
1449-
if not input_grad and not weight_grad and not bias_grad:
1450-
continue
1451-
x = torch.randn(N, C, H, W, device=device, requires_grad=input_grad)
1452-
weight = torch.randn(C, device=device, requires_grad=weight_grad)
1453-
bias = torch.randn(C, device=device, requires_grad=bias_grad)
1454-
grad_out = torch.randn(B, N, C, H, W, device=device)
1455-
loop_out = loop(f, (None, None, None, 0), 0, 2, x, weight, bias, grad_out)
1456-
batched_out = vmap(f, (None, None, None, 0), 0)(x, weight, bias, grad_out)
1457-
self.assertEqual(loop_out, batched_out)
1462+
@ops(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db),
     allowed_dtypes=(torch.float32, torch.double))  # TODO: generalize
def test_group_norm_backward(self, device, dtype, op):
    # Recognizes the all-ones placeholder emitted when autograd.grad had
    # nothing differentiable to work with.  Hacky: relies on group norm
    # never taking scalar inputs, so a (batch_size,) all-ones tensor can
    # only be the placeholder.
    def is_skip_placeholder(out, batch_size):
        if out.shape != (batch_size,):
            return False
        return all(bool(e == 1) for e in out)

    for sample in op.sample_inputs(device, dtype, requires_grad=True):
        cotangents = get_sample_cotangents(op, sample)
        fn, fn_args = get_autograd_fn_and_args_with_cotangents(op, sample, cotangents)
        for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, fn_args, {}, opinfo=op):
            all_placeholders = all(
                is_skip_placeholder(bo, lo.shape[0])
                for bo, lo in zip(batched_out, loop_out))
            if all_placeholders:
                # we weren't able to use the batched tensor in autograd.grad
                continue
            self.assertEqual(loop_out, batched_out)
14581478

14591479

14601480
only_for = ("cpu", "cuda")

0 commit comments

Comments (0)