Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 3b9e79a

Browse files
committed
fix None consistency
1 parent ca362fd commit 3b9e79a

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

functorch/_src/decompositions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,23 +477,23 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
477477
if output_mask[0]:
478478
d_input = aten.mul(aten.div(rstd, N), inner)
479479
else:
480-
d_input = aten.new_empty(input, (0,))
480+
d_input = None
481481

482482
if output_mask[1] and weight is not None:
483483
if len(outer_dim_indices) > 0:
484484
d_weight = aten.sum(aten.mul(grad_out, x_hat), outer_dim_indices, False)
485485
else:
486486
d_weight = aten.mul(grad_out, x_hat)
487487
else:
488-
d_weight = aten.new_empty(input, (0,))
488+
d_weight = None
489489

490490
if output_mask[2] and bias is not None:
491491
if len(outer_dim_indices) > 0:
492492
d_bias = aten.sum(grad_out, outer_dim_indices, False)
493493
else:
494494
d_bias = grad_out
495495
else:
496-
d_bias = aten.new_empty(input, (0,))
496+
d_bias = None
497497
return (d_input, d_weight, d_bias)
498498

499499
# @register_decomposition(aten.addmm)

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1313,7 +1313,7 @@ def get_names(inpt):
13131313
f.write(f'{op}\n')
13141314

13151315
def test_decompositions_torchscriptable(self, device):
1316-
skip_list = []
1316+
skip_list = [torch.ops.aten.native_layer_norm_backward.default]
13171317
for op, decomposition in decomposition_table.items():
13181318
if op in skip_list:
13191319
continue

0 commit comments

Comments
 (0)