Skip to content

Commit 5a6113f

Browse files
Arm backend: Add TOSA dialect op for MATMUL (#14694)
Adds a TOSA backend dialect op for MATMUL and an associated pass that rewrites edge.aten.bmm to tosa.MATMUL. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 91f1769 commit 5a6113f

File tree

13 files changed

+258
-148
lines changed

13 files changed

+258
-148
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
ReplaceScalarWithTensorArgPassTOSABI,
9292
ReplaceScalarWithTensorArgPassTOSAMI,
9393
)
94+
from .rewrite_matmul import RewriteMatmulPass # noqa
9495
from .rewrite_upsample import RewriteUpsamplePass # noqa
9596
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
9697
from .size_adjust_input_pass import SizeAdjustInputPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
ReplaceScalarWithTensorArgPassTOSABI,
9393
ReplaceScalarWithTensorArgPassTOSAMI,
9494
RetraceFoldedDtypesPass,
95+
RewriteMatmulPass,
9596
RewriteUpsamplePass,
9697
ScalarsToAttributePass,
9798
SizeAdjustInputPass,
@@ -211,6 +212,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
211212
self.add_pass(RewriteUpsamplePass(exported_program))
212213
self.add_pass(AddBiasPass(exported_program))
213214

215+
self.add_pass(RewriteMatmulPass(exported_program))
214216
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
215217
self.add_pass(ToTosaMemoryFormatPass(exported_program))
216218
self.add_pass(RemoveNoopPass())
@@ -297,6 +299,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
297299
self.add_pass(RewriteUpsamplePass(exported_program))
298300
self.add_pass(AddBiasPass(exported_program))
299301
self.add_pass(InsertTableOpsPass(exported_program))
302+
self.add_pass(RewriteMatmulPass(exported_program))
300303
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
301304
self.add_pass(ToTosaMemoryFormatPass(exported_program))
302305
self.add_pass(RemoveNoopPass())

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def call(self, graph_module):
114114
if node.op != "call_function":
115115
continue
116116
if node.target in [
117+
exir_ops.backend.tosa.MATMUL.default,
117118
exir_ops.backend.tosa.RESCALE.default,
118119
exir_ops.backend.tosa.RESIZE.default,
119120
exir_ops.backend.tosa.TABLE.default,
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.arm_pass_utils import (
11+
create_node,
12+
get_first_fake_tensor,
13+
)
14+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15+
get_input_qparams,
16+
get_output_qparams,
17+
)
18+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
19+
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.pass_base import ExportPass, PassResult
21+
22+
23+
class RewriteMatmulPass(ArmPass):
    """Rewrite ``edge.aten.bmm`` nodes to ``tosa.MATMUL``.

    Every ``call_function`` node targeting ``exir_ops.edge.aten.bmm.default``
    is replaced by a ``exir_ops.backend.tosa.MATMUL.default`` node with the
    same two inputs. If evaluating the MATMUL op on the inputs' fake tensors
    yields an int32 result while the original bmm node's fake tensor is int8
    or int16, a ``tosa.RESCALE`` node is inserted after the MATMUL, with the
    original output dtype passed as its dtype argument.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
        """Insert a ``tosa.RESCALE`` after ``tosa_matmul_node`` and reroute its users.

        Args:
            graph_module: Graph module whose graph is being rewritten.
            node: The original bmm node. Only its quantization parameters
                (as returned by ``get_input_qparams`` / ``get_output_qparams``)
                are read here.
            tosa_matmul_node: The MATMUL node whose output is rescaled.
            dtype: Value passed as the RESCALE node's dtype argument; the
                caller passes the original bmm output dtype.
        """
        input_qparams = get_input_qparams(node)
        output_qparams = get_output_qparams(node)[0]
        # Combined rescale factor: (scale_in0 * scale_in1) / scale_out.
        scale = (
            input_qparams[0].get_scale_per_tensor()
            * input_qparams[1].get_scale_per_tensor()
        ) / output_qparams.get_scale_per_tensor()

        with graph_module.graph.inserting_after(tosa_matmul_node):
            # The MATMUL result is rescaled to `dtype` using the combined
            # input/output scale computed above.
            rescale_node = create_node(
                graph_module.graph,
                op_target=exir_ops.backend.tosa.RESCALE.default,
                from_node=tosa_matmul_node,
            )
            # Reroute the MATMUL's users first and only then wire the RESCALE
            # to the MATMUL; doing it the other way round would also rewrite
            # the RESCALE's own input to point at itself.
            tosa_matmul_node.replace_all_uses_with(rescale_node)
            rescale_node.args = (
                tosa_matmul_node,
                dtype,
                scale,
                0,  # presumably the input zero-point — confirm against the RESCALE op schema
                output_qparams.get_zp_per_tensor(),
            )

    def call(self, graph_module):
        """Run the rewrite over ``graph_module``.

        Returns:
            A ``PassResult`` holding the (possibly re-traced) graph module and
            a flag telling whether any bmm node was rewritten.
        """
        modified = False
        # Iterate over a snapshot: the loop body inserts and erases nodes.
        for node in list(graph_module.graph.nodes):
            if (
                node.op != "call_function"
                or node.target != exir_ops.edge.aten.bmm.default
            ):
                continue
            modified = True

            x1, x2 = node.args
            tosa_matmul_target = exir_ops.backend.tosa.MATMUL.default
            with graph_module.graph.inserting_before(node):
                tosa_matmul_node = create_node(
                    graph_module.graph,
                    op_target=tosa_matmul_target,
                    args=(x1, x2),
                    kwargs={},
                    from_node=node,
                )
                node.replace_all_uses_with(tosa_matmul_node)

            x1_fake_tensor = get_first_fake_tensor(x1)
            x2_fake_tensor = get_first_fake_tensor(x2)
            # Evaluate the MATMUL op on fake tensors to learn its output dtype.
            output_fake_tensor = tosa_matmul_target(x1_fake_tensor, x2_fake_tensor)
            node_output_fake_tensor = get_first_fake_tensor(node)
            if (
                output_fake_tensor.dtype == torch.int32
                and node_output_fake_tensor.dtype in (torch.int8, torch.int16)
            ):
                self._insert_output_rescale(
                    graph_module,
                    node,
                    tosa_matmul_node,
                    dtype=node_output_fake_tensor.dtype,
                )
            if x1_fake_tensor.dtype == torch.int16:
                # int16 inputs: tag the MATMUL node with the INT48 special dtype.
                tosa_matmul_node.meta[TosaSpecialDtype.meta_key()] = (
                    TosaSpecialDtype.INT48
                )

            # Erase the bmm node only after its metadata (fake tensor and
            # qparams) has been read above, so no erased node is accessed.
            graph_module.graph.erase_node(node)

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)

backends/arm/operators/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
op_any,
1515
op_avg_pool2d,
1616
op_bitwise_not,
17-
op_bmm,
1817
op_cat,
1918
op_ceil,
2019
op_clamp,
@@ -42,19 +41,20 @@
4241
op_pow,
4342
op_reciprocal,
4443
op_repeat,
45-
op_rescale,
46-
op_resize,
4744
op_rshift_tensor,
4845
op_rsqrt,
4946
op_sigmoid,
5047
op_sin,
5148
op_slice,
5249
op_sub,
5350
op_sum,
54-
op_table,
5551
op_tanh,
5652
op_to_dim_order_copy,
57-
op_transpose,
53+
op_tosa_matmul,
54+
op_tosa_rescale,
55+
op_tosa_resize,
56+
op_tosa_table,
57+
op_tosa_transpose,
5858
op_view,
5959
op_where,
6060
ops_binary,

backends/arm/operators/op_bmm.py

Lines changed: 0 additions & 143 deletions
This file was deleted.

0 commit comments

Comments
 (0)