Commit effd1f2

Arm backend: Add TOSA dialect op for MATMUL
Adds a TOSA backend dialect op for MATMUL and an associated pass that rewrites edge.aten.bmm to tosa.MATMUL.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I578e5f7333922e02402dabc24ef1b12adf383b18
1 parent f7c009e commit effd1f2
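
For context: a minimal sketch (not part of this commit) of how a graph containing edge.aten.bmm arises, assuming the standard torch.export / to_edge flow:

    # Minimal sketch, assuming the standard torch.export / to_edge flow.
    import torch
    from torch.export import export
    from executorch.exir import to_edge


    class BmmModule(torch.nn.Module):
        def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
            # Batched matmul: (B, M, K) @ (B, K, N) -> (B, M, N)
            return torch.bmm(x1, x2)


    example_inputs = (torch.randn(2, 4, 8), torch.randn(2, 8, 16))
    edge = to_edge(export(BmmModule(), example_inputs))
    # The edge graph now contains an edge.aten.bmm.default node, which the
    # new RewriteMatmulPass replaces with the backend dialect op tosa.MATMUL.
    print(edge.exported_program().graph)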

File tree: 8 files changed, +172 -48 lines

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -91,6 +91,7 @@
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
 )
+from .rewrite_matmul import RewriteMatmulPass  # noqa
 from .rewrite_upsample import RewriteUpsamplePass  # noqa
 from .scalars_to_attribute_pass import ScalarsToAttributePass  # noqa
 from .size_adjust_input_pass import SizeAdjustInputPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions

@@ -91,6 +91,7 @@
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
+    RewriteMatmulPass,
     RewriteUpsamplePass,
     ScalarsToAttributePass,
     SizeAdjustInputPass,
@@ -210,6 +211,8 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(RewriteUpsamplePass(exported_program))
         self.add_pass(AddBiasPass(exported_program))
 
+        self.add_pass(InsertTableOpsPass(exported_program))
+        self.add_pass(RewriteMatmulPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
@@ -295,6 +298,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(RewriteUpsamplePass(exported_program))
         self.add_pass(AddBiasPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
+        self.add_pass(RewriteMatmulPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 1 addition & 0 deletions

@@ -114,6 +114,7 @@ def call(self, graph_module):
         if node.op != "call_function":
             continue
         if node.target in [
+            exir_ops.backend.tosa.MATMUL.default,
             exir_ops.backend.tosa.RESCALE.default,
             exir_ops.backend.tosa.RESIZE.default,
             exir_ops.backend.tosa.TABLE.default,
backends/arm/_passes/rewrite_matmul.py (new file)

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+    get_output_qparams,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class RewriteMatmulPass(ArmPass):
+    """Rewrites aten.bmm to tosa.MATMUL and inserts a tosa.RESCALE op if needed."""
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    def _insert_output_rescale(self, graph_module, node, tosa_matmul_node):
+        input_qparams = get_input_qparams(node)
+        output_qparams = get_output_qparams(node)[0]
+        scale = (
+            input_qparams[0].get_scale_per_tensor()
+            * input_qparams[1].get_scale_per_tensor()
+        ) / output_qparams.get_scale_per_tensor()
+
+        with graph_module.graph.inserting_after(tosa_matmul_node):
+            # int8 inputs accumulate into int32, so rescale back to int8.
+            rescale_node = create_node(
+                graph_module.graph,
+                op_target=exir_ops.backend.tosa.RESCALE.default,
+                from_node=tosa_matmul_node,
+            )
+            tosa_matmul_node.replace_all_uses_with(rescale_node)
+            rescale_node.args = (
+                tosa_matmul_node,
+                torch.int8,
+                scale,
+                0,
+                output_qparams.get_zp_per_tensor(),
+            )
+
+    def call(self, graph_module):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if (
+                node.op != "call_function"
+                or node.target != exir_ops.edge.aten.bmm.default
+            ):
+                continue
+            modified = True
+
+            x1, x2 = node.args
+            tosa_matmul_target = exir_ops.backend.tosa.MATMUL.default
+            with graph_module.graph.inserting_before(node):
+                tosa_matmul_node = create_node(
+                    graph_module.graph,
+                    op_target=tosa_matmul_target,
+                    args=(x1, x2),
+                    kwargs={},
+                    from_node=node,
+                )
+            node.replace_all_uses_with(tosa_matmul_node)
+            graph_module.graph.erase_node(node)
+
+            x1_fake_tensor = get_first_fake_tensor(x1)
+            x2_fake_tensor = get_first_fake_tensor(x2)
+            output_fake_tensor = tosa_matmul_target(x1_fake_tensor, x2_fake_tensor)
+            node_output_fake_tensor = get_first_fake_tensor(node)
+            if (
+                output_fake_tensor.dtype == torch.int32
+                and node_output_fake_tensor.dtype == torch.int8
+            ):
+                self._insert_output_rescale(graph_module, node, tosa_matmul_node)
+
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified)
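
The multiplier computed in _insert_output_rescale follows the usual per-tensor quantization identity: the int32 accumulator represents real values at scale s_x1 * s_x2, so requantizing to int8 divides by the output scale. A numeric sketch with illustrative (made-up) quantization parameters:

    # Illustrative values only: why the single multiplier (s_x1 * s_x2) / s_out
    # requantizes the int32 MATMUL accumulator back to int8.
    s_x1, s_x2 = 0.02, 0.05      # input scales (assumed)
    s_out, zp_out = 0.10, -3     # output scale / zero point (assumed)

    acc_int32 = 1234             # a zero-point-corrected int8 dot product

    real = s_x1 * s_x2 * acc_int32           # value the accumulator represents
    scale = (s_x1 * s_x2) / s_out            # what RewriteMatmulPass computes
    q_out = round(acc_int32 * scale) + zp_out
    assert q_out == round(real / s_out) + zp_out  # same result: 9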

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,6 @@
     op_any,
     op_avg_pool2d,
     op_bitwise_not,
-    op_bmm,
     op_cat,
     op_ceil,
     op_clamp,
@@ -33,6 +32,7 @@
     op_log,
     op_logical_not,
     op_lt,
+    op_matmul,
     op_max_pool2d,
     op_maximum,
     op_minimum,

backends/arm/operators/op_bmm.py renamed to backends/arm/operators/op_matmul.py

Lines changed: 21 additions & 47 deletions

@@ -13,7 +13,6 @@
 
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_input_qparams,
-    get_output_qparams,
 )
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
@@ -26,20 +25,13 @@
 )
 from executorch.backends.arm.tosa import TosaSpecification
 from executorch.backends.arm.tosa.mapping import TosaArg
-from executorch.backends.arm.tosa.quant_utils import build_rescale
-from tosa.RoundingMode import RoundingMode  # type: ignore
 
 
 @register_node_visitor
-class BMMVisitor(NodeVisitor):
-    """Provide a visitor that lowers ``aten.bmm`` to TOSA ``MATMUL``.
+class MatmulVisitor(NodeVisitor):
+    """Provide a visitor that serializes TOSA ``MATMUL``."""
 
-    INT8 accumulates into INT32; add a rescale to INT8 using SINGLE_ROUND
-    rounding and output zero-point.
-
-    """
-
-    target = "aten.bmm.default"
+    target = "tosa.MATMUL.default"
 
     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -56,35 +48,36 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        """Define the TOSA ``MATMUL`` operator and optional rescale."""
+        """Define the TOSA ``MATMUL`` operator."""
         import serializer.tosa_serializer as ts  # type: ignore
 
         validate_num_inputs(self.target, inputs, 2)
-        validate_same_dtype(self.target, [*inputs, output], ts)
+        validate_same_dtype(self.target, [*inputs], ts)
         validate_valid_dtype(
             self.target,
-            [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
+            [*inputs],
+            [ts.DType.INT8, ts.DType.FP32],
+            output.tosa_spec,
+        )
+        validate_valid_dtype(
+            self.target,
+            [output],
+            [ts.DType.INT32, ts.DType.FP32],
             output.tosa_spec,
         )
 
-        # aten.bmm maps directly to MATMUL
-
-        # For INT8, we need to get the zero points and add an intermediate tensor
-        # for a later rescale.
-
+        # For INT8, fetch the per-tensor zero points of both inputs.
         if inputs[0].dtype == ts.DType.INT8:
             input_qparams = get_input_qparams(node)
             input0_zp = input_qparams[0].get_zp_per_tensor()
             input1_zp = input_qparams[1].get_zp_per_tensor()
-            bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
-            bmm_output_name = bmm_result.name
         else:
-            bmm_output_name = output.name
             input0_zp, input1_zp = 0, 0
 
-        tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=f"{node.name}_A_ZP")
-        tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP")
+        input_A_ZP_name = f"{node.name}_A_ZP"
+        input_B_ZP_name = f"{node.name}_B_ZP"
+        tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=input_A_ZP_name)
+        tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name)
 
         # Add the MATMUL to the TOSA graph.
         self._serialize_operator(
@@ -94,27 +87,8 @@ def define_node(
             [
                 inputs[0].name,
                 inputs[1].name,
-                f"{node.name}_A_ZP",
-                f"{node.name}_B_ZP",
+                input_A_ZP_name,
+                input_B_ZP_name,
             ],
-            [bmm_output_name],
+            [output.name],
         )
-
-        # As INT8 accumulates into INT32, we need to rescale it back to INT8
-        if output.dtype == ts.DType.INT8:
-            output_qparams = get_output_qparams(node)[0]
-            final_output_scale = (
-                input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore[61]
-            ) / output_qparams.get_scale_per_tensor()
-
-            build_rescale(
-                tosa_fb=tosa_graph,
-                scale=[final_output_scale],
-                # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
-                input_node=bmm_result,  # type: ignore[possibly-undefined]
-                output_name=output.name,
-                output_type=ts.DType.INT8,
-                input_zp=[0],
-                output_zp=[output_qparams.get_zp_per_tensor()],
-                rounding_mode=RoundingMode.SINGLE_ROUND,
-            )
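
The A_ZP/B_ZP constants exist because TOSA MATMUL subtracts a per-tensor zero point from each int8 operand before accumulating in int32. A plain-PyTorch sketch of that reference semantics (illustrative helper, not the serializer path above):

    # Reference semantics for TOSA MATMUL with int8 operands:
    # out_int32 = (A - A_zp) @ (B - B_zp), accumulated in int32.
    import torch

    def tosa_matmul_int8_reference(
        a_q: torch.Tensor,  # int8, shape (N, H, C)
        b_q: torch.Tensor,  # int8, shape (N, C, W)
        a_zp: int,
        b_zp: int,
    ) -> torch.Tensor:
        a = a_q.to(torch.int32) - a_zp
        b = b_q.to(torch.int32) - b_zp
        # Explicit multiply-and-sum over C, avoiding any reliance on
        # integer bmm kernel support: (N, H, C, 1) * (N, 1, C, W) -> (N, H, W)
        return (a.unsqueeze(3) * b.unsqueeze(1)).sum(dim=2)

    # For FP32 inputs the zero points are simply 0, matching the
    # `input0_zp, input1_zp = 0, 0` branch in the visitor.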

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from executorch.backends.arm.tosa.dialect.ops import (  # noqa F401
+    matmul,
     rescale,
     resize,
     table,
backends/arm/tosa/dialect/ops/matmul.py (new file)

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.tosa.dialect.lib import TosaValueError
+from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
+
+from executorch.backends.arm.tosa.specification import (
+    get_context_spec,
+    TosaSpecification,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_fake_tosa_op(
+    "MATMUL(Tensor input1, Tensor input2) -> Tensor",  # schema
+    (
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+    ),  # target TOSA specifications
+)
+def MATMUL(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+    """Performs matrix multiplication on two input tensors.
+
+    Additionally validates TOSA constraints of a MATMUL op.
+    """
+    tosa_spec = get_context_spec()
+    if x1.dtype != x2.dtype:
+        raise TosaValueError(
+            f"Input tensors must have the same dtype, got {x1.dtype} and {x2.dtype}",
+            op="MATMUL",
+        )
+    if x1.dtype in (torch.int8,):
+        if not tosa_spec.support_integer():
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} doesn't support integers", op="MATMUL"
+            )
+        else:
+            dtype = torch.int32
+    elif x1.dtype in (torch.float16, torch.float32):
+        if not tosa_spec.support_float():
+            raise TosaValueError(
+                f"TOSA spec {tosa_spec} doesn't support float", op="MATMUL"
+            )
+        else:
+            # float16 supports float16 accumulation as well
+            dtype = torch.float32
+    else:
+        raise TosaValueError(
+            f"Input tensors must be of type int8, float16 or float32, got {x1.dtype}",
+            op="MATMUL",
+        )
+
+    aten_fake_tensor = exir_ops.edge.aten.bmm.default(x1, x2)
+
+    return torch.empty_like(aten_fake_tensor, dtype=dtype)
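
During export this fake op only has to report the correct output shape and dtype. A standalone sketch mirroring its dtype rule (illustrative; in the pipeline the real op is reached via exir_ops.backend.tosa.MATMUL.default on FakeTensors):

    # Standalone mirror of the fake implementation's dtype rule:
    # int8 x int8 -> int32 accumulator, fp32 x fp32 -> fp32.
    import torch

    def matmul_meta(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        dtype = torch.int32 if x1.dtype == torch.int8 else torch.float32
        # (N, H, C) @ (N, C, W) -> (N, H, W)
        out_shape = (*x1.shape[:-1], x2.shape[-1])
        return torch.empty(out_shape, dtype=dtype)

    q = matmul_meta(
        torch.zeros(2, 4, 8, dtype=torch.int8),
        torch.zeros(2, 8, 16, dtype=torch.int8),
    )
    # The int32 output is what lets RewriteMatmulPass detect that an
    # output RESCALE back to int8 is needed.
    assert q.dtype == torch.int32 and q.shape == (2, 4, 16)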

0 commit comments

Comments
 (0)