Skip to content

Commit fab5687

Browse files
authored
Arm backend: Add addmm decomposition pass and test (#12668)
Decomposes addmm into matmul and add operators. Signed-off-by: Teo Bergkvist <[email protected]>
1 parent e31eb56 commit fab5687

File tree

5 files changed

+223
-0
lines changed

5 files changed

+223
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .convert_to_clamp import ConvertToClampPass # noqa
2525
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2626
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
27+
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
2728
from .decompose_asin_pass import DecomposeAsinPass # noqa
2829
from .decompose_atan_pass import DecomposeAtanPass # noqa
2930
from .decompose_atanh_pass import DecomposeAtanhPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ConvertToClampPass,
3030
DecomposeAcoshPass,
3131
DecomposeAdaptiveAvgPool2dPass,
32+
DecomposeAddmmPass,
3233
DecomposeAsinPass,
3334
DecomposeAtanhPass,
3435
DecomposeAtanPass,
@@ -165,6 +166,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
165166
self.add_pass(DecomposeSqrtPass())
166167
self.add_pass(DecomposeAtanPass())
167168
self.add_pass(DecomposeAtanhPass())
169+
self.add_pass(DecomposeAddmmPass())
168170
self.add_pass(ConvertIntPowToMuls())
169171
self.add_pass(CastBoolToInt8Pass())
170172
self.add_pass(DecomposeSinhPass())
@@ -257,6 +259,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
257259
self.add_pass(DecomposeRoundPass())
258260
self.add_pass(CastBoolToInt8Pass())
259261
self.add_pass(DecomposeSignPass())
262+
self.add_pass(DecomposeAddmmPass())
260263
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
261264
self.add_pass(ScalarsToAttributePass())
262265
self.add_pass(DecomposeGroupNormPass())
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
11+
12+
# For MI case
13+
edge_addmm = exir_ops.edge.aten.addmm.default
14+
# For BI case
15+
aten_addmm = torch.ops.aten.addmm.default
16+
17+
18+
def get_ops(op):
19+
"""Returns the appropriate operator functions based on the input operator."""
20+
if op == edge_addmm:
21+
return (
22+
exir_ops.edge.aten.mm.default,
23+
exir_ops.edge.aten.mul.Scalar,
24+
exir_ops.edge.aten.add.Tensor,
25+
)
26+
elif op == aten_addmm:
27+
return (
28+
torch.ops.aten.mm.default,
29+
torch.ops.aten.mul.Scalar,
30+
torch.ops.aten.add.Tensor,
31+
)
32+
else:
33+
raise ValueError(f"Unsupported operator: {op}")
34+
35+
36+
class DecomposeAddmmPass(ArmPass):
37+
"""Decomposes the addmm operator into tensor multiplication and addition."""
38+
39+
def call_operator(self, op, args, kwargs, meta):
40+
if op not in [edge_addmm, aten_addmm]:
41+
return super().call_operator(op, args, kwargs, meta)
42+
43+
input, mat1, mat2 = args
44+
beta = kwargs.get("beta", 1.0)
45+
alpha = kwargs.get("alpha", 1.0)
46+
47+
mul_op, mul_scalar_op, add_op = get_ops(op)
48+
49+
mul = super().call_operator(mul_op, (mat1, mat2), {}, meta, updated=True)
50+
mul_alpha = super().call_operator(
51+
mul_scalar_op, (mul, alpha), {}, meta, updated=True
52+
)
53+
54+
input_beta = super().call_operator(
55+
mul_scalar_op, (input, beta), {}, meta, updated=True
56+
)
57+
58+
return super().call_operator(
59+
add_op, (mul_alpha, input_beta), {}, meta, updated=True
60+
)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def is_node_supported(
253253
exir_ops.edge.aten.sign.default,
254254
exir_ops.edge.aten.asin.default,
255255
exir_ops.edge.aten.atanh.default,
256+
exir_ops.edge.aten.addmm.default,
256257
]
257258

258259
return supported
@@ -293,6 +294,7 @@ def is_node_supported(
293294
exir_ops.edge.aten.div.Scalar: None,
294295
exir_ops.edge.aten.leaky_relu.default: None,
295296
exir_ops.edge.aten.round.default: None,
297+
exir_ops.edge.aten.addmm.default: None,
296298
}
297299

298300
if node.target in needs_decomp_dict:

backends/arm/test/ops/test_addmm.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
aten_op = "torch.ops.aten.addmm.default"
19+
20+
exir_op = "executorch_exir_dialects_edge__ops_aten__addmm_default"
21+
22+
input_t1 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] # Input x1, x2, x3
23+
24+
25+
test_data_suite = {
26+
"basic": [
27+
torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
28+
torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
29+
torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
30+
1.0,
31+
1.0,
32+
],
33+
"zeros": [torch.zeros(2, 2), torch.zeros(2, 3), torch.zeros(3, 2), 1.0, 1.0],
34+
"beta_only": [
35+
torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
36+
torch.randn(2, 3),
37+
torch.randn(3, 2),
38+
0.0,
39+
1.0,
40+
],
41+
"alpha_only": [
42+
torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
43+
torch.randn(2, 3),
44+
torch.randn(3, 2),
45+
1.0,
46+
0.0,
47+
],
48+
"scaled": [
49+
torch.ones(2, 2),
50+
torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
51+
torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
52+
0.5,
53+
2.0,
54+
],
55+
"negative_scalars": [
56+
torch.tensor([[1.0, -1.0], [-1.0, 1.0]]),
57+
torch.tensor([[2.0, 0.0], [0.0, 2.0]]),
58+
torch.tensor([[1.0, 1.0], [1.0, 1.0]]),
59+
-1.0,
60+
-1.0,
61+
],
62+
"non_square": [torch.ones(3, 4), torch.rand(3, 2), torch.rand(2, 4), 1.0, 1.0],
63+
"large_values": [
64+
torch.full((2, 2), 1e6),
65+
torch.full((2, 3), 1e3),
66+
torch.full((3, 2), 1e3),
67+
1.0,
68+
1.0,
69+
],
70+
"small_values": [
71+
torch.full((2, 2), 1e-6),
72+
torch.full((2, 3), 1e-3),
73+
torch.full((3, 2), 1e-3),
74+
1.0,
75+
1.0,
76+
],
77+
"random": [torch.randn(4, 5), torch.randn(4, 3), torch.randn(3, 5), 1.0, 1.0],
78+
"broadcast_bias_row": [
79+
torch.randn(1, 2),
80+
torch.randn(3, 4),
81+
torch.randn(4, 2),
82+
1.0,
83+
1.0,
84+
],
85+
"row_bias": [
86+
torch.randn(3, 1),
87+
torch.randn(3, 4),
88+
torch.randn(4, 4),
89+
1.0,
90+
1.0,
91+
],
92+
"scalar_bias": [
93+
torch.tensor(2.0),
94+
torch.randn(5, 3),
95+
torch.randn(3, 6),
96+
1.0,
97+
1.0,
98+
],
99+
}
100+
101+
102+
class Addmm(torch.nn.Module):
103+
def forward(
104+
self,
105+
x1: torch.Tensor,
106+
x2: torch.Tensor,
107+
x3: torch.Tensor,
108+
alpha: float,
109+
beta: float,
110+
) -> torch.Tensor:
111+
return torch.addmm(x1, x2, x3, alpha=alpha, beta=beta)
112+
113+
114+
@common.parametrize("test_data", test_data_suite)
115+
def test_addmm_tosa_MI(test_data: Tuple):
116+
pipeline = TosaPipelineMI[input_t1](
117+
Addmm(),
118+
(*test_data,),
119+
aten_op=aten_op,
120+
exir_op=exir_op,
121+
)
122+
pipeline.run()
123+
124+
125+
@common.parametrize("test_data", test_data_suite)
126+
def test_addmm_tosa_BI(test_data: Tuple):
127+
pipeline = TosaPipelineBI[input_t1](
128+
Addmm(),
129+
(*test_data,),
130+
aten_op=[],
131+
exir_op=exir_op,
132+
)
133+
pipeline.run()
134+
135+
136+
@common.XfailIfNoCorstone300
137+
@common.parametrize("test_data", test_data_suite)
138+
def test_addmm_u55_BI(test_data: Tuple):
139+
pipeline = EthosU55PipelineBI[input_t1](
140+
Addmm(),
141+
(*test_data,),
142+
aten_ops=[],
143+
exir_ops=exir_op,
144+
)
145+
pipeline.run()
146+
147+
148+
@common.XfailIfNoCorstone320
149+
@common.parametrize("test_data", test_data_suite)
150+
def test_addmm_u85_BI(test_data: Tuple):
151+
pipeline = EthosU85PipelineBI[input_t1](
152+
Addmm(),
153+
(*test_data,),
154+
aten_ops=[],
155+
exir_ops=exir_op,
156+
)
157+
pipeline.run()

0 commit comments

Comments
 (0)