Skip to content

Commit d8a00e6

Browse files
Add mul-op for Arm backend
Differential Revision: D61341327 Pull Request resolved: #4730
1 parent b66d62a commit d8a00e6

File tree

5 files changed

+240
-0
lines changed

5 files changed

+240
-0
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4545
exir_ops.edge.aten.div.Tensor,
4646
exir_ops.edge.aten.split_with_sizes_copy.default,
4747
exir_ops.edge.aten.full.default,
48+
exir_ops.edge.aten.mul.Tensor,
4849
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
4950
exir_ops.edge.aten.avg_pool2d.default,
5051
exir_ops.edge.aten.sigmoid.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
op_hardtanh,
1818
op_mean_dim,
1919
op_mm,
20+
op_mul,
2021
op_permute,
2122
op_quant,
2223
op_repeat,

backends/arm/operators/op_mul.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import List
7+
8+
import executorch.backends.arm.tosa_quant_utils as tqutils
9+
import executorch.backends.arm.tosa_utils as tutils
10+
11+
import serializer.tosa_serializer as ts
12+
import torch
13+
14+
from executorch.backends.arm.operators.node_visitor import (
15+
NodeVisitor,
16+
register_node_visitor,
17+
)
18+
from executorch.backends.arm.tosa_mapping import TosaArg
19+
from serializer.tosa_serializer import TosaOp
20+
21+
22+
@register_node_visitor
class MulVisitor(NodeVisitor):
    """Node visitor that emits a TOSA ``MUL`` operator for ``aten.mul.Tensor``."""

    target = "aten.mul.Tensor"

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Add the operators implementing ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered; ``node.args[0]`` and
                ``node.args[1]`` are used to look up quantization arguments
                on the quantized path.
            tosa_graph: Serializer that receives the new operators and
                intermediate tensors.
            inputs: TOSA arguments of the node; the first two are the factors.
            output: TOSA argument describing the result tensor.
            is_quant_node: Selects the quantized lowering when true, otherwise
                a single ``MUL`` is emitted directly on the inputs.

        Note:
            On the quantized path the ``shape`` attribute of both input
            ``TosaArg`` objects is overwritten with the result of
            ``tutils.tosa_shape(shape, dim_order)``.
        """
        if not is_quant_node:
            # Non-quantized path: one MUL straight from inputs to output.
            mul_attr = ts.TosaSerializerAttribute()
            mul_attr.MulAttribute(shift=0)
            tosa_graph.addOperator(
                TosaOp.Op().MUL,
                [inputs[0].name, inputs[1].name],
                [output.name],
                mul_attr,
            )
            return

        lhs = inputs[0]
        rhs = inputs[1]
        lhs_qargs = tqutils.get_quant_node_args(node.args[0])
        rhs_qargs = tqutils.get_quant_node_args(node.args[1])

        # Re-express all shapes via tosa_shape() using each argument's
        # dim_order before any tensors are created from them.
        lhs.shape = tutils.tosa_shape(lhs.shape, lhs.dim_order)
        rhs.shape = tutils.tosa_shape(rhs.shape, rhs.dim_order)
        result_shape = tutils.tosa_shape(output.shape, output.dim_order)

        # Rescale inputs to INT32 with zp=0
        lhs_int32 = tqutils.build_rescale_to_int32(
            tosa_graph,
            lhs,
            lhs_qargs.zp,
            rescale_scale=1.0,
        )
        rhs_int32 = tqutils.build_rescale_to_int32(
            tosa_graph,
            rhs,
            rhs_qargs.zp,
            rescale_scale=1.0,
        )

        product = tosa_graph.addIntermediate(result_shape, ts.DType.INT32)

        # Do the INT32 Mul
        mul_attr = ts.TosaSerializerAttribute()
        mul_attr.MulAttribute(shift=0)
        tosa_graph.addOperator(
            TosaOp.Op().MUL,
            [lhs_int32.name, rhs_int32.name],
            [product.name],
            mul_attr,
        )

        # The INT32 product is handed back to the int8 rescale helper with
        # the product of the two input scales.
        product_scale = lhs_qargs.scale * rhs_qargs.scale
        tqutils.rescale_node_back_to_int8(node, product, product_scale, tosa_graph)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern
7373
[torch.nn.AdaptiveAvgPool2d],
7474
[F.adaptive_avg_pool2d],
7575
],
76+
"mul": [torch.mul],
7677
"sub": [[torch.sub]],
7778
}
7879
return copy.deepcopy(supported_operators)

backends/arm/test/ops/test_mul.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
import torch
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
from parameterized import parameterized
14+
15+
test_data_sute = [
16+
# (test_name, input, other,) See torch.mul() for info
17+
(
18+
"op_mul_rank1_ones",
19+
torch.ones(5),
20+
torch.ones(5),
21+
),
22+
(
23+
"op_mul_rank2_rand",
24+
torch.rand(4, 5),
25+
torch.rand(1, 5),
26+
),
27+
(
28+
"op_mul_rank3_randn",
29+
torch.randn(10, 5, 2),
30+
torch.randn(10, 5, 2),
31+
),
32+
(
33+
"op_mul_rank4_randn",
34+
torch.randn(5, 10, 25, 20),
35+
torch.randn(5, 10, 25, 20),
36+
),
37+
(
38+
"op_mul_rank4_ones_mul_negative",
39+
torch.ones(1, 10, 25, 20),
40+
(-1) * torch.ones(5, 10, 25, 20),
41+
),
42+
(
43+
"op_mul_rank4_negative_large_rand",
44+
(-200) * torch.rand(5, 10, 25, 20),
45+
torch.rand(5, 1, 1, 20),
46+
),
47+
(
48+
"op_mul_rank4_large_randn",
49+
200 * torch.randn(5, 10, 25, 20),
50+
torch.rand(5, 10, 25, 1),
51+
),
52+
]
53+
54+
55+
class TestMul(unittest.TestCase):
56+
class Mul(torch.nn.Module):
57+
58+
def forward(
59+
self,
60+
input_: torch.Tensor,
61+
other_: torch.Tensor,
62+
):
63+
return input_ * other_
64+
65+
def _test_mul_tosa_MI_pipeline(
66+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
67+
):
68+
(
69+
ArmTester(
70+
module,
71+
example_inputs=test_data,
72+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
73+
)
74+
.export()
75+
.check_count({"torch.ops.aten.mul.Tensor": 1})
76+
.check_not(["torch.ops.quantized_decomposed"])
77+
.to_edge()
78+
.partition()
79+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
80+
.to_executorch()
81+
.run_method_and_compare_outputs(inputs=test_data)
82+
)
83+
84+
def _test_mul_tosa_BI_pipeline(
85+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
86+
):
87+
(
88+
ArmTester(
89+
module,
90+
example_inputs=test_data,
91+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
92+
)
93+
.quantize()
94+
.export()
95+
.check_count({"torch.ops.aten.mul.Tensor": 1})
96+
.check(["torch.ops.quantized_decomposed"])
97+
.to_edge()
98+
.partition()
99+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
100+
.to_executorch()
101+
.run_method_and_compare_outputs(inputs=test_data, qtol=1.0)
102+
)
103+
104+
def _test_mul_u55_BI_pipeline(
105+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
106+
):
107+
(
108+
ArmTester(
109+
module,
110+
example_inputs=test_data,
111+
compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True),
112+
)
113+
.quantize()
114+
.export()
115+
.check_count({"torch.ops.aten.mul.Tensor": 1})
116+
.check(["torch.ops.quantized_decomposed"])
117+
.to_edge()
118+
.partition()
119+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
120+
.to_executorch()
121+
)
122+
123+
@parameterized.expand(test_data_sute)
124+
def test_mul_tosa_MI(
125+
self,
126+
test_name: str,
127+
input_: torch.Tensor,
128+
other_: torch.Tensor,
129+
):
130+
test_data = (input_, other_)
131+
self._test_mul_tosa_MI_pipeline(self.Mul(), test_data)
132+
133+
@parameterized.expand(test_data_sute)
134+
def test_mul_tosa_BI(
135+
self,
136+
test_name: str,
137+
input_: torch.Tensor,
138+
other_: torch.Tensor,
139+
):
140+
141+
test_data = (input_, other_)
142+
self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)
143+
144+
# Expected to fail since RESCALE cannot be fused with MUL in Vela.
145+
@parameterized.expand(test_data_sute)
146+
@unittest.expectedFailure
147+
def test_mul_u55_BI(
148+
self,
149+
test_name: str,
150+
input_: torch.Tensor,
151+
other_: torch.Tensor,
152+
):
153+
test_data = (input_, other_)
154+
self._test_mul_u55_BI_pipeline(self.Mul(), test_data)

0 commit comments

Comments
 (0)