Commit 08f09d9

ezyang authored and pytorchmergebot committed
Ensure rms_norm decomp generates add.Scalar for pattern match BC (pytorch#165437)
Summary: Apparently if I just do `tensor + eps` this turns into add.Tensor, which is bad because the constant Tensor ends up getting hoisted into an input, which is a bozo thing to do. Just make sure it's exactly compatible.

Test Plan:
```
buck run 'fbcode//mode/opt' fbcode//bolt/nn/executorch/backends/tests:qnn_test_ar1g1 bolt.nn.executorch.backends.tests.qnn_test_ar1g1.QnnTestAR1G1.test_RMSNorm
```

Reviewed By: tugsbayasgalan

Differential Revision: D84613184

Pull Request resolved: pytorch#165437
Approved by: https://github.com/tugsbayasgalan
1 parent 74acf92 commit 08f09d9

1 file changed: +4 -1 lines changed
torch/_decomp/decompositions.py

Lines changed: 4 additions & 1 deletion
@@ -1783,7 +1783,10 @@ def _fused_rms_norm(
 
     rqrst_input = torch.rsqrt(
         # NB: don't inplace here, will violate functional IR invariant
-        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
+        # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
+        torch.ops.aten.add.Scalar(
+            torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
+        )
     )
 
     upcasted_result = upcasted_input.mul(rqrst_input)
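
For illustration only, here is a minimal sketch of the two spellings the diff swaps between. `rms_norm_sketch` is a hypothetical, simplified helper (last-dimension reduction, no weight), not the real `_fused_rms_norm`. In eager mode both spellings should produce the same values; per the summary above, the difference shows up in which `add` overload ends up in the traced graph.

```python
import torch


def rms_norm_sketch(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Hypothetical, simplified RMS norm over the last dim (no weight)."""
    upcasted = x.to(torch.float32)
    mean_sq = torch.pow(upcasted, 2).mean(dim=-1, keepdim=True)
    # Call the Scalar overload explicitly, as the patched decomp does,
    # instead of relying on Tensor.add(eps) to pick an overload.
    rsqrt = torch.rsqrt(torch.ops.aten.add.Scalar(mean_sq, eps))
    return upcasted.mul(rsqrt).to(x.dtype)


x = torch.randn(2, 8)
eps = 1e-6

# Patched spelling: explicit add.Scalar.
explicit = rms_norm_sketch(x, eps)

# Pre-patch spelling: Tensor.add with a Python float.
implicit = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(eps))
```

Since the eager results agree, the choice of overload matters only to consumers that pattern-match on the graph, such as the backend exercised in the test plan.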
