Skip to content

Commit 43efc37

Browse files
authored
Arm: Add pass that converts operators to clamp to the Arm backend (#8538)
Add pass that converts operators to clamp to the Arm backend * Add ConvertToClampPass that converts relu and hardtanh to clamp * Remove op_relu and op_hardtanh visitors from backend Signed-off-by: Tom Allsop <[email protected]>
1 parent 4b1ae21 commit 43efc37

File tree

6 files changed

+120
-127
lines changed

6 files changed

+120
-127
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
2828
ConvertSqueezesToViewPass,
2929
)
30+
from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass
3031
from executorch.backends.arm._passes.decompose_batchnorm_pass import (
3132
DecomposeBatchNormPass,
3233
)
@@ -104,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
104105
self.add_pass(DecomposeLinearPass())
105106
self.add_pass(ConvertMeanDimToAveragePoolPass())
106107
self.add_pass(ConvertFullLikeToFullPass())
108+
self.add_pass(ConvertToClampPass())
107109

108110
self.add_pass(ReplaceScalarWithTensorArgPass())
109111
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -144,6 +146,8 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
144146
self.add_pass(DecomposeDivPass())
145147
self.add_pass(DecomposeSoftmaxesPass())
146148
self.add_pass(ConvertFullLikeToFullPass())
149+
self.add_pass(ConvertToClampPass())
150+
147151
self.add_pass(AnnotateDecomposedMatmulPass())
148152
self.add_pass(QuantizeOperatorArguments())
149153
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
edge_operators = {
12+
exir_ops.edge.aten.hardtanh.default,
13+
exir_ops.edge.aten.relu.default,
14+
}
15+
16+
17+
def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
18+
if op == exir_ops.edge.aten.hardtanh.default:
19+
return args[1], args[2]
20+
elif op == exir_ops.edge.aten.relu.default:
21+
return 0.0, None
22+
else:
23+
raise ValueError(f"Getting clamp parameters for op {op} is not implemented.")
24+
25+
26+
class ConvertToClampPass(ExportPass):
27+
def call_operator(self, op, args, kwargs, meta):
28+
if op not in edge_operators:
29+
return super().call_operator(op, args, kwargs, meta)
30+
31+
return super().call_operator(
32+
exir_ops.edge.aten.clamp.default,
33+
(args[0], *get_clamp_params(op, args)),
34+
{},
35+
meta,
36+
)

backends/arm/operators/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
op_ge,
2222
op_get_item,
2323
op_gt,
24-
op_hardtanh,
2524
op_le,
2625
op_log,
2726
op_lt,
@@ -31,7 +30,6 @@
3130
op_mul,
3231
op_permute,
3332
op_reciprocal,
34-
op_relu,
3533
op_repeat,
3634
op_rescale,
3735
op_rshift,

backends/arm/operators/op_hardtanh.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

backends/arm/operators/op_relu.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
8+
import torch
9+
from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass
10+
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
14+
from executorch.backends.xnnpack.test.tester.tester import RunPasses
15+
16+
17+
class HardTanh(torch.nn.Module):
18+
def __init__(self):
19+
super().__init__()
20+
21+
self.hardtanh = torch.nn.Hardtanh()
22+
23+
def forward(self, x):
24+
return self.hardtanh(x)
25+
26+
def get_inputs(self):
27+
return (torch.rand(1, 64, 64, 3),)
28+
29+
30+
class ReLU(torch.nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
34+
self.relu = torch.nn.ReLU()
35+
36+
def forward(self, x):
37+
return self.relu(x)
38+
39+
def get_inputs(self):
40+
return (torch.rand(1, 64, 64, 3),)
41+
42+
43+
class TestConvertToClampPass(unittest.TestCase):
44+
"""
45+
Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
46+
"""
47+
48+
def test_tosa_MI_hardtahn(self):
49+
module = HardTanh()
50+
test_pass_stage = RunPasses([ConvertToClampPass])
51+
(
52+
ArmTester(
53+
module,
54+
example_inputs=module.get_inputs(),
55+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
56+
)
57+
.export()
58+
.to_edge()
59+
.check(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
60+
.run_passes(test_pass_stage)
61+
.check(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
62+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
63+
)
64+
65+
def test_tosa_MI_relu(self):
66+
module = ReLU()
67+
test_pass_stage = RunPasses([ConvertToClampPass])
68+
(
69+
ArmTester(
70+
module,
71+
example_inputs=module.get_inputs(),
72+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
73+
)
74+
.export()
75+
.to_edge()
76+
.check(["executorch_exir_dialects_edge__ops_aten_relu_default"])
77+
.run_passes(test_pass_stage)
78+
.check(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
79+
.check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
80+
)

0 commit comments

Comments
 (0)