Skip to content

Commit dda2705

Browse files
authored
Arm backend: Decompose sub/add with alpha!=1 (#14932)
This was previously not supported, causing crashes in quantization, and incorrect output in floating point. Signed-off-by: Erik Lundell <[email protected]>
1 parent 418c584 commit dda2705

File tree

5 files changed

+125
-1
lines changed

5 files changed

+125
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .convert_to_clamp import ConvertToClampPass # noqa
2828
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2929
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
30+
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
3031
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
3132
from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
3233
from .decompose_asinh_pass import DecomposeAsinhPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DecomposeAcoshPass,
3737
DecomposeAdaptiveAvgPool2dPass,
3838
DecomposeAddmmPass,
39+
DecomposeAddSubAlphaPass,
3940
DecomposeAsinAndAcosPass,
4041
DecomposeAsinhPass,
4142
DecomposeAtanhPass,
@@ -262,6 +263,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
262263
)
263264
self.add_pass(DecomposeNotEqualPass())
264265
self.add_pass(DecomposeDivPass())
266+
self.add_pass(DecomposeAddSubAlphaPass())
265267
self.add_pass(DecomposeSoftmaxPass())
266268
self.add_pass(DecomposeGeluPass())
267269
self.add_pass(ConvertFullLikeToFullPass())
@@ -334,6 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
334336
self.add_pass(DecomposeSignPass())
335337
self.add_pass(DecomposeAddmmPass())
336338
self.add_pass(DecomposeDivTensorModePass())
339+
self.add_pass(DecomposeAddSubAlphaPass())
337340
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
338341
self.add_pass(ScalarsToAttributePass())
339342
self.add_pass(DecomposeGroupNormPass())
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
import numbers
9+
from typing import Set, Type
10+
11+
import torch
12+
from executorch.backends.arm._passes import ArmPass
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass
15+
16+
17+
_ADD_OPS = (
18+
exir_ops.edge.aten.add.Tensor,
19+
torch.ops.aten.add.Tensor,
20+
)
21+
22+
_SUB_OPS = (
23+
exir_ops.edge.aten.sub.Tensor,
24+
torch.ops.aten.sub.Tensor,
25+
)
26+
27+
28+
def _get_ops(op):
29+
if op in _ADD_OPS:
30+
if op is exir_ops.edge.aten.add.Tensor:
31+
return (
32+
exir_ops.edge.aten.mul.Tensor,
33+
exir_ops.edge.aten.full.default,
34+
exir_ops.edge.aten.add.Tensor,
35+
)
36+
return (
37+
torch.ops.aten.mul.Tensor,
38+
torch.ops.aten.full.default,
39+
torch.ops.aten.add.Tensor,
40+
)
41+
if op in _SUB_OPS:
42+
if op is exir_ops.edge.aten.sub.Tensor:
43+
return (
44+
exir_ops.edge.aten.mul.Tensor,
45+
exir_ops.edge.aten.full.default,
46+
exir_ops.edge.aten.sub.Tensor,
47+
)
48+
return (
49+
torch.ops.aten.mul.Tensor,
50+
torch.ops.aten.full.default,
51+
torch.ops.aten.sub.Tensor,
52+
)
53+
raise RuntimeError(f"Unsupported operator {op}")
54+
55+
56+
def _should_decompose(alpha) -> bool:
57+
if isinstance(alpha, numbers.Number):
58+
return alpha != 1
59+
return False
60+
61+
62+
class DecomposeAddSubAlphaPass(ArmPass):
63+
"""Rewrite add/sub with alpha into a mul followed by add/sub."""
64+
65+
_passes_required_after: Set[Type[ExportPass]] = set()
66+
67+
def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
68+
if op not in _ADD_OPS + _SUB_OPS:
69+
return super().call_operator(op, args, kwargs, meta, updated)
70+
71+
alpha = kwargs.get("alpha", 1)
72+
if not _should_decompose(alpha):
73+
return super().call_operator(op, args, kwargs, meta, updated)
74+
75+
mul_op, full_op, binary_op = _get_ops(op)
76+
lhs, rhs = args
77+
78+
alpha_full = super().call_operator(
79+
full_op, ((1,), float(alpha)), {}, meta, updated=True
80+
)
81+
scaled_rhs = super().call_operator(
82+
mul_op,
83+
(rhs, alpha_full),
84+
{},
85+
meta,
86+
updated=True,
87+
)
88+
return super().call_operator(
89+
binary_op,
90+
(lhs, scaled_rhs),
91+
{},
92+
meta,
93+
updated=True,
94+
)

backends/arm/test/ops/test_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
7878

7979
class Add3(torch.nn.Module):
8080
def forward(self, x: torch.Tensor, y: torch.Tensor):
81-
return x + y
81+
return torch.add(x, y, alpha=1.5)
8282

8383
test_data: list[input_t2] = {
8484
"3d_randn_diff_rank": lambda: (torch.randn(1, 4, 5), torch.randn(4, 1)),

backends/arm/test/ops/test_sub.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
7979
return x - y
8080

8181

82+
class SubAlpha(torch.nn.Module):
83+
def forward(self, x: torch.Tensor, y: torch.Tensor):
84+
return torch.sub(x, y, alpha=5)
85+
86+
8287
class SubTan(torch.nn.Module):
8388

8489
def forward(self, x: torch.Tensor, y: torch.Tensor):
@@ -115,6 +120,18 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
115120
pipeline.run()
116121

117122

123+
@common.parametrize("test_data", sub_tan_test_data)
124+
def test_sub_tensor_tosa_FP_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
125+
"""Test Two-Operand Subtraction with alpha (TOSA FP)"""
126+
pipeline = TosaPipelineFP[input_t2](
127+
SubAlpha(),
128+
test_data(),
129+
aten_op,
130+
exir_op,
131+
)
132+
pipeline.run()
133+
134+
118135
@common.parametrize("test_data", sub_test_data)
119136
def test_sub_tensor_tosa_INT(test_data):
120137
"""Test Subtraction (TOSA INT)"""
@@ -138,6 +155,15 @@ def test_sub_tensor_tosa_INT_3(test_data: Tuple[torch.Tensor, torch.Tensor]):
138155
pipeline.run()
139156

140157

158+
@common.parametrize("test_data", sub_tan_test_data)
159+
def test_sub_tensor_tosa_INT_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
160+
"""Test Two-Operand Subtraction with alpha (TOSA INT)"""
161+
pipeline = TosaPipelineINT[input_t2](
162+
SubAlpha(), test_data(), aten_op, exir_op, qtol=0
163+
)
164+
pipeline.run()
165+
166+
141167
@common.parametrize("test_data", sub_test_data)
142168
@common.XfailIfNoCorstone300
143169
def test_sub_tensor_u55_INT(test_data):

0 commit comments

Comments
 (0)