
Commit 6b87dcd

fixed some minor nits with decomposition

1 parent: 481fada

2 files changed: 12 additions, 11 deletions

functorch/_src/decompositions.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -439,6 +439,7 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
         return out
     return beta * self + out
 
+
 @register_decomposition(aten.native_layer_norm_backward)
 def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
     input_shape = input.shape
@@ -447,18 +448,18 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     axis = input_ndim - len(normalized_shape)
     inner_dims = input_shape[axis:]
     outer_dims = input_shape[:axis]
-    inner_dim_indices = []
-    outer_dim_indices = []
+    inner_dim_indices: List[int] = []
+    outer_dim_indices: List[int] = []
     for i in range(input_ndim):
         if(i >= axis):
             inner_dim_indices.append(i)
         else:
             outer_dim_indices.append(i)
 
-    N = float(prod(inner_dims))
+    N = prod(inner_dims)
     M = prod(outer_dims)
-    if M <= 0 or N <= 0.0:
-        return (aten.new_empty(input, input_shape), aten.new_zeros(input[axis:], input_shape[axis:]), aten.new_zeros(input[axis:], input_shape[axis:]))
+    if M <= 0 or N <= 0:
+        return (aten.new_zeros(input, input_shape), aten.new_zeros(input, input_shape[axis:]), aten.new_zeros(input, input_shape[axis:]))
 
     x_hat = aten.mul(aten.sub(input, mean), rstd)
     if weight is not None:
@@ -476,23 +477,23 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
     if output_mask[0]:
         d_input = aten.mul(aten.div(rstd, N), inner)
     else:
-        d_input = None
+        d_input = aten.new_empty(input, (0,))
 
     if output_mask[1] and weight is not None:
         if len(outer_dim_indices) > 0:
             d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
         else:
             d_weight = aten.mul(grad_out, x_hat)
     else:
-        d_weight = None
+        d_weight = aten.new_empty(input, (0,))
 
     if output_mask[2] and bias is not None:
         if len(outer_dim_indices) > 0:
             d_bias = aten.sum(grad_out, outer_dim_indices, False)
         else:
             d_bias = grad_out
     else:
-        d_bias = None
+        d_bias = aten.new_empty(input, (0,))
     return (d_input, d_weight, d_bias)
 
 # @register_decomposition(aten.addmm)
```
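
Replacing the `None` placeholders with `aten.new_empty(input, (0,))` keeps the returned tuple consistent with the declared `Tuple[Tensor, Tensor, Tensor]` type, which is also what lets `native_layer_norm_backward` come off the TorchScript skip list in `test/test_ops.py` below. As a minimal sketch of how such a decomposition could be sanity-checked against autograd — the import path, shapes, and tolerances here are illustrative assumptions, not part of the commit:

```python
# Hypothetical sanity check (assumed import path; not part of this commit):
# compare the decomposition's gradients against autograd's, in float64.
import torch
from functorch._src.decompositions import native_layer_norm_backward

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
weight = torch.randn(8, dtype=torch.float64, requires_grad=True)
bias = torch.randn(8, dtype=torch.float64, requires_grad=True)
normalized_shape = [8]

# Forward pass via the aten op; returns (out, mean, rstd).
out, mean, rstd = torch.native_layer_norm(x, normalized_shape, weight, bias, 1e-5)
grad_out = torch.randn_like(out)

# Reference gradients from autograd.
ref = torch.autograd.grad(out, (x, weight, bias), grad_out)

# Gradients from the decomposition, with all three outputs requested.
dec = native_layer_norm_backward(grad_out, x, normalized_shape, mean, rstd,
                                 weight, bias, [True, True, True])

for r, d in zip(ref, dec):
    assert torch.allclose(r, d, rtol=1e-6, atol=1e-6)
```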

test/test_ops.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1109,8 +1109,8 @@ def op_assert_equal(op, a, b, arg_string):
     # Before adding an entry to this table, make sure your decomposition is right :)
     tol_table = {
         # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
-        (torch.float32, aten.native_layer_norm): (1e-3, 1e-3),
-        (torch.float32, aten.native_layer_norm_backward): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm.default): (1e-3, 1e-3),
+        (torch.float32, aten.native_layer_norm_backward.default): (1e-3, 1e-3),
     }
     if (b.dtype, op) in tol_table:
         rtol, atol = tol_table[(b.dtype, op)]
@@ -1313,7 +1313,7 @@ def get_names(inpt):
                 f.write(f'{op}\n')
 
     def test_decompositions_torchscriptable(self, device):
-        skip_list = [torch.ops.aten.native_layer_norm_backward]
+        skip_list = []
         for op, decomposition in decomposition_table.items():
             if op in skip_list:
                 continue
```
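
Keying the tolerance table on `aten.native_layer_norm.default` rather than the bare `aten.native_layer_norm` matters because the bare name is an `OpOverloadPacket`, while a specific overload such as `.default` is an `OpOverload`; if the `op` looked up in `op_assert_equal` is the resolved overload, a packet-keyed entry would never match. A standalone illustration of the distinction (not from the test file):

```python
import torch

aten = torch.ops.aten

packet = aten.native_layer_norm            # OpOverloadPacket: bundle of all overloads
overload = aten.native_layer_norm.default  # OpOverload: one specific schema

print(type(packet).__name__)    # OpOverloadPacket
print(type(overload).__name__)  # OpOverload
print(packet == overload)       # False -- distinct objects, so a dict keyed
                                # on one is never hit by the other
```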
