This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 481fada

Add decomposition for aten.native_layer_norm_backward op. (#525)
Signed-Off-By: Prateek Gupta <[email protected]>
1 parent 128764e commit 481fada
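
For reference, the decomposition added below implements the standard closed-form layer-norm backward (the notation here is mine, not from the commit). With \(\hat{x}\) the normalized input and \(g\) the gradient flowing into it,

\[
\hat{x} = (x - \mu)\,\mathrm{rstd}, \qquad
g = \mathrm{grad\_out} \odot w \quad (\text{or } g = \mathrm{grad\_out} \text{ when no weight is given}),
\]
\[
\frac{\partial L}{\partial x} = \frac{\mathrm{rstd}}{N}\Bigl(N\,g - \sum_{\mathrm{inner}} g - \hat{x}\sum_{\mathrm{inner}} g \odot \hat{x}\Bigr), \qquad
\frac{\partial L}{\partial w} = \sum_{\mathrm{outer}} \mathrm{grad\_out} \odot \hat{x}, \qquad
\frac{\partial L}{\partial b} = \sum_{\mathrm{outer}} \mathrm{grad\_out},
\]

where \(N\) is the number of elements in the normalized (inner) dimensions, the inner sums reduce over those dimensions (kept for broadcasting), and the outer sums reduce over the leading dimensions.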

File tree

2 files changed: 73 additions & 3 deletions

functorch/_src/decompositions.py

Lines changed: 67 additions & 1 deletion
@@ -381,7 +381,7 @@ def native_layer_norm(input: Tensor, normalized_shape: List[int], weight: Option
     if M > 0:
         input_reshaped = input.view(1, M, -1)
     else:
-        return (input, aten.new_empty(input, (0,)), aten.new_empty(input, (0,)))
+        return (input, aten.new_zeros(input, (0,)), aten.new_zeros(input, (0,)))
 
     # Unlike Batch Normalization, which applies scalar scale and bias for each
     # entire channel/plane with the affine option, Layer Normalization applies
@@ -439,6 +439,72 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+@register_decomposition(aten.native_layer_norm_backward)
+def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
+    input_shape = input.shape
+    input_ndim = input.dim()
+
+    axis = input_ndim - len(normalized_shape)
+    inner_dims = input_shape[axis:]
+    outer_dims = input_shape[:axis]
+    inner_dim_indices = []
+    outer_dim_indices = []
+    for i in range(input_ndim):
+        if(i >= axis):
+            inner_dim_indices.append(i)
+        else:
+            outer_dim_indices.append(i)
+
+    N = float(prod(inner_dims))
+    M = prod(outer_dims)
+    if M <= 0 or N <= 0.0:
+        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))
+
+    x_hat = aten.mul(aten.sub(input, mean), rstd)
+    if weight is not None:
+        grad_x_hat = aten.mul(grad_out, weight)
+    else:
+        grad_x_hat = grad_out
+    a = aten.mul(grad_x_hat, N)
+    b = aten.sum(grad_x_hat, inner_dim_indices, True)
+    c1 = aten.mul(grad_x_hat, x_hat)
+    c2 = aten.sum(c1, inner_dim_indices, True)
+    c3 = aten.mul(x_hat, c2)
+
+    inner = aten.sub(aten.sub(a, b), c3)
+
+    if output_mask[0]:
+        d_input = aten.mul(aten.div(rstd, N), inner)
+    else:
+        d_input = None
+
+    if output_mask[1] and weight is not None:
+        if len(outer_dim_indices) > 0:
+            d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
+        else:
+            d_weight = aten.mul(grad_out, x_hat)
+    else:
+        d_weight = None
+
+    if output_mask[2] and bias is not None:
+        if len(outer_dim_indices) > 0:
+            d_bias = aten.sum(grad_out, outer_dim_indices, False)
+        else:
+            d_bias = grad_out
+    else:
+        d_bias = None
+    return (d_input, d_weight, d_bias)
+
+# @register_decomposition(aten.addmm)
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta=1, alpha=1):
+#     if not self.is_floating_point():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * aten.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
+#     return beta * self + out
+
 
 @register_decomposition(aten.clamp_min)
 def clamp_min(self: Tensor, min: float):
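
A quick way to sanity-check the decomposition above outside of functorch is to restate it with plain tensor ops and compare against autograd. This is a minimal sketch under my own assumptions (the shapes, tolerances, and the helper name layer_norm_backward_ref are illustrative, not from the commit):

# Standalone sanity check: restate the decomposition with ordinary torch ops
# and compare the results against autograd.
import math
import torch

def layer_norm_backward_ref(grad_out, x, normalized_shape, mean, rstd, weight):
    # Split dimensions into leading (outer) and normalized (inner), as the decomposition does.
    axis = x.dim() - len(normalized_shape)
    inner = list(range(axis, x.dim()))
    outer = list(range(axis))
    N = float(math.prod(x.shape[axis:]))
    x_hat = (x - mean) * rstd
    g = grad_out * weight if weight is not None else grad_out
    d_input = rstd / N * (N * g
                          - g.sum(inner, keepdim=True)
                          - x_hat * (g * x_hat).sum(inner, keepdim=True))
    d_weight = (grad_out * x_hat).sum(outer) if outer else grad_out * x_hat
    d_bias = grad_out.sum(outer) if outer else grad_out
    return d_input, d_weight, d_bias

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], w, b, 1e-5)
grad_out = torch.randn_like(out)
out.backward(grad_out)

d_x, d_w, d_b = layer_norm_backward_ref(grad_out, x, [8], mean, rstd, w)
torch.testing.assert_close(d_x, x.grad, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(d_w, w.grad, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(d_b, b.grad, rtol=1e-3, atol=1e-3)

The 1e-3 tolerances mirror the entry this commit adds to tol_table in test/test_ops.py.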

test/test_ops.py

Lines changed: 6 additions & 2 deletions
@@ -1109,7 +1109,8 @@ def op_assert_equal(op, a, b, arg_string):
     # Before adding an entry to this table, make sure your decomposition is right :)
     tol_table = {
         # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
-        (torch.float32, aten.native_layer_norm.default): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm_backward): (1e-3, 1e-3),
     }
     if (b.dtype, op) in tol_table:
         rtol, atol = tol_table[(b.dtype, op)]
@@ -1231,6 +1232,9 @@ def call_op(func, map_fn, *args, **kwargs):
         real_out = call_op(func, unwrap_tensor, *args, **kwargs)
         assert(len(real_out) == len(decomp_out))
         for orig, decomp, ref in zip(real_out, decomp_out, real_out_double):
+            if orig is None:
+                assert(decomp is None)
+                continue
             orig = orig.to(dtype=TEST_DTYPE)
             decomp = decomp.to(dtype=TEST_DTYPE)
             if DO_RELATIVE_CHECK and ref.dtype.is_floating_point:
@@ -1309,7 +1313,7 @@ def get_names(inpt):
                 f.write(f'{op}\n')
 
     def test_decompositions_torchscriptable(self, device):
-        skip_list = []
+        skip_list = [torch.ops.aten.native_layer_norm_backward]
         for op, decomposition in decomposition_table.items():
             if op in skip_list:
                 continue
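
The new check for orig being None is needed because aten.native_layer_norm_backward returns None for any gradient that output_mask switches off, and the decomposition added in this commit mirrors that behaviour. A minimal illustration (my example, not part of the commit):

import torch

x = torch.randn(4, 8)
w, b = torch.randn(8), torch.randn(8)
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], w, b, 1e-5)

# Ask only for the input gradient; the weight and bias slots are masked off.
d_x, d_w, d_b = torch.ops.aten.native_layer_norm_backward(
    torch.randn_like(out), x, [8], mean, rstd, w, b, [True, False, False])
assert d_x is not None and d_w is None and d_b is None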
