Commit 915aecb

Author: samdow

batch norm forward over reverse coverage with decomposition
1 parent 2b16530 commit 915aecb

File tree: 4 files changed, +97 -10 lines changed


functorch/_src/decompositions.py

Lines changed: 93 additions & 7 deletions
@@ -47,6 +47,17 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
     return min - torch.log1p(z), buffer


+def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
+    # for most norm decompositions, it will be the same as the core version except for here.
+    # We recompute the mean and variance so that they track gradients through input
+
+    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd = 1 / torch.sqrt(var + eps)
+    return mean, rstd
+
 @register_decomposition_for_jvp(aten.native_layer_norm_backward)
 def native_layer_norm_backward(
     grad_out: Tensor,
@@ -80,13 +91,7 @@ def native_layer_norm_backward(
             input.new_zeros(input_shape[axis:]),
         )

-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

     x_hat = (input - mean_) * rstd_
     if weight is not None:
@@ -128,3 +133,84 @@ def native_layer_norm_backward(
     d_bias = torch.zeros(())  # should be None but doesn't work with vjp

     return (d_input, d_weight, d_bias)
+
+
+def prod(x: List[int]):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_batch_norm_backward)  # @register_decomposition_for_jvp after in core
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+
+        reduction_dims = [0] + list(range(2, input.dim()))
+        assert invstd is not None  # for typing
+        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes: List[int] = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+    elif weight is not None:
+        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    else:
+        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp
+
+    return (grad_input, grad_weight, grad_bias)
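
Aside (not part of the diff): a minimal, self-contained sketch of the trick recompute_mean_var relies on. The saved statistic satisfies rstd = 1 / sqrt(var + eps), so eps can be recovered as 1 / rstd**2 - var, detached, and rstd rebuilt from a variance that does track gradients through the input. The shapes and eps value below are illustrative assumptions.

import torch

x = torch.randn(8, 3, requires_grad=True)
true_eps = 1e-5

# what a forward pass would have saved: statistics with no grad history through x
with torch.no_grad():
    saved_var = torch.var(x, dim=0, unbiased=False)
    saved_rstd = 1 / torch.sqrt(saved_var + true_eps)

# recompute the variance so it tracks gradients through x, then recover eps from saved_rstd
var = torch.var(x, dim=0, unbiased=False)
eps = (torch.pow(1 / saved_rstd, 2) - var).detach()  # equals true_eps up to float error
rstd = 1 / torch.sqrt(var + eps)

print(torch.allclose(rstd, saved_rstd))  # True: same value as the saved statistic
print(rstd.requires_grad)                # True: but now differentiable w.r.t. x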

functorch/_src/eager_transforms.py

Lines changed: 1 addition & 0 deletions
@@ -1339,5 +1339,6 @@ def _register_python_decomposition_vmap(decomp):
 _register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
 _register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
 _register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
+_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default, use_python=True)
 _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
 _register_python_decomposition_vmap(torch.ops.aten.addr.default)

functorch/csrc/DynamicLayer.cpp

Lines changed: 1 addition & 0 deletions
@@ -502,6 +502,7 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
   OP_DECOMPOSE(log_sigmoid);
   JVP_DECOMP(log_sigmoid_forward);
   JVP_DECOMP(native_layer_norm_backward);
+  JVP_DECOMP(native_batch_norm_backward);
 }


test/test_ops.py

Lines changed: 2 additions & 3 deletions
@@ -1146,8 +1146,6 @@ def get_vjp(cotangents, *primals):
        xfail('logdet', ''),
        xfail('nanmean', ''),
        xfail('nansum', ''),
-       xfail('nn.functional.batch_norm', ''),
-       xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
        xfail('nn.functional.embedding'),
        xfail('nn.functional.embedding', 'functorch'),
        xfail('nn.functional.embedding_bag', ''),
@@ -1246,7 +1244,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
            'softmax',
            'log_softmax',
            'nn.functional.cross_entropy',
-           'nn.functional.layer_norm'
+           'nn.functional.layer_norm',
+           'nn.functional.batch_norm',
        }
        if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
            self.assertFalse(op.supports_fwgrad_bwgrad,
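
For context (not part of the diff), here is a minimal sketch of the jvp-of-vjp pattern these tests exercise for batch norm, which now routes through the decomposition added above. The functorch jvp/vjp calls are the real API; the shapes and the use of functional batch norm with batch statistics are illustrative assumptions.

import torch
from functorch import jvp, vjp

def f(x):
    # functional batch norm using batch statistics (training mode, no running stats)
    return torch.nn.functional.batch_norm(x, None, None, training=True)

x = torch.randn(4, 3, 5, 5)
cotangent = torch.randn(4, 3, 5, 5)
tangent = torch.randn(4, 3, 5, 5)

def get_vjp(primal):
    _, vjp_fn = vjp(f, primal)
    return vjp_fn(cotangent)[0]

# forward-mode over reverse-mode: differentiate the vjp of batch_norm along `tangent`;
# under the jvp transform this hits native_batch_norm_backward and its decomposition
grad_x, jvp_of_grad_x = jvp(get_vjp, (x,), (tangent,))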

0 commit comments
