Skip to content

Commit 2fe6c7e

Browse files
authored
Fix the decomposition of native_layer_norm_backward operation. (#674)
The return type for native_layer_norm_backward is a tuple of optional tensors as the output_mask decides whether or not the gradient of a particular input is to be calculated or not. Signed-Off-By: Prateek Gupta <[email protected]>
1 parent 090798d commit 2fe6c7e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

functorch/_src/decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int =
441441

442442

443443
@register_decomposition(aten.native_layer_norm_backward)
444-
def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Tensor, Tensor, Tensor]:
444+
def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape: List[int], mean: Tensor, rstd: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], output_mask: List[bool]) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
445445
input_shape = input.shape
446446
input_ndim = input.dim()
447447

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ def get_names(inpt):
15551555
f.write(f'{op}\n')
15561556

15571557
def test_decompositions_torchscriptable(self, device):
1558-
skip_list = [torch.ops.aten.native_layer_norm_backward.default]
1558+
skip_list = []
15591559
for op, decomposition in decomposition_table.items():
15601560
if op in skip_list:
15611561
continue

0 commit comments

Comments
 (0)