Skip to content

Commit 4a8fb17

Browse files
authored
Add clamp operator to Arm backend (#7935)
* Add clamp operator to Arm backend
* Add support for aten.clamp
* Amend QuantizeFullArgument pass to include quantization of clamp arguments

  Signed-off-by: Tom Allsop <[email protected]>
  Change-Id: I432f4ec60facc50fe45ca05c98308924d6e18109

* Address merge conflicts and typing issues

  Signed-off-by: Tom Allsop <[email protected]>
  Change-Id: If9f6bbb9d23f00b9cea6d00c6aa97a6d5d3e77ab

---------

Signed-off-by: Tom Allsop <[email protected]>
1 parent b71b25b commit 4a8fb17

File tree

7 files changed

+350
-15
lines changed

7 files changed

+350
-15
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -39,7 +39,7 @@
3939
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
4040
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
4141
FoldAndAnnotateQParamsPass,
42-
QuantizeFullArgument,
42+
QuantizeOperatorArguments,
4343
RetraceFoldedDtypesPass,
4444
)
4545
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
@@ -92,7 +92,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9292
self.add_pass(ConvertMeanDimToAveragePoolPass())
9393

9494
self.add_pass(AnnotateDecomposedMatmulPass())
95-
self.add_pass(QuantizeFullArgument())
95+
self.add_pass(QuantizeOperatorArguments())
9696
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
9797
self.add_pass(RetraceFoldedDtypesPass())
9898
self.add_pass(InsertTableOpsPass(exported_program))
@@ -128,7 +128,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
128128
self.add_pass(DecomposeSoftmaxesPass())
129129

130130
self.add_pass(AnnotateDecomposedMatmulPass())
131-
self.add_pass(QuantizeFullArgument())
131+
self.add_pass(QuantizeOperatorArguments())
132132
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
133133
self.add_pass(RetraceFoldedDtypesPass())
134134
self.add_pass(InsertTableOpsPass(exported_program))

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 34 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -167,19 +167,25 @@ def call(self, graph_module: GraphModule) -> PassResult:
167167
return PassResult(graph_module, True)
168168

169169

170-
class QuantizeFullArgument(ExportPass):
170+
class QuantizeOperatorArguments(ExportPass):
171171
"""
172-
Make sure the fill_value for full.default is quantized. This pass needs to be run before
173-
the folding pass above to make sure that the retraced output of the full.default op is
174-
the right dtype.
172+
This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
173+
More specifically, this pass:
174+
- Makes sure the fill_value for full.default is quantized. This pass needs to be run before
175+
the folding pass above to make sure that the retraced output of the full.default op is
176+
the right dtype.
177+
- Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
175178
"""
176179

177180
def call(self, graph_module: GraphModule) -> PassResult:
178181
modified = False
179182
# Loop over the graph nodes and find full.default nodes.
180183
for n in graph_module.graph.nodes:
181184
n = cast(Node, n)
182-
if n.target != exir_ops.edge.aten.full.default:
185+
if n.target not in {
186+
exir_ops.edge.aten.clamp.default,
187+
exir_ops.edge.aten.full.default,
188+
}:
183189
continue
184190

185191
# Make sure we have a quantized operator
@@ -188,13 +194,29 @@ def call(self, graph_module: GraphModule) -> PassResult:
188194
continue
189195

190196
qargs = QuantArgs.from_operator(user.target, user.args)
191-
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
192-
# replace the node arg with a quantized dito and also set dtype
193-
# to get the right output according to the Edge IR specification:
194-
# exir/dialects/edge/edge.yaml:3596
195-
quantized_full_value = qargs.quantize_value(n.args[1]).item()
196-
n.update_arg(1, quantized_full_value)
197-
n.update_kwarg("dtype", qargs.dtype)
197+
198+
if n.target == exir_ops.edge.aten.full.default:
199+
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
200+
# replace the node arg with a quantized dito and also set dtype
201+
# to get the right output according to the Edge IR specification:
202+
# exir/dialects/edge/edge.yaml:3596
203+
quantized_full_value = qargs.quantize_value(n.args[1]).item()
204+
n.update_arg(1, quantized_full_value)
205+
n.update_kwarg("dtype", qargs.dtype)
206+
modified = True
207+
elif n.target == exir_ops.edge.aten.clamp.default:
208+
# Quantize the min and max arguments of clamp, if they are not None
209+
min_val = n.args[1]
210+
max_val = None if len(n.args) <= 2 else n.args[2]
211+
212+
if min_val is not None:
213+
quantized_min_val = qargs.quantize_value(min_val).item()
214+
n.update_arg(1, quantized_min_val)
215+
216+
if max_val is not None:
217+
quantized_max_val = qargs.quantize_value(max_val).item()
218+
n.update_arg(2, quantized_max_val)
219+
198220
modified = True
199221

200222
return PassResult(graph_module, modified)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -76,6 +76,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
7676
exir_ops.edge.aten.add.Tensor,
7777
exir_ops.edge.aten.expand_copy.default,
7878
exir_ops.edge.aten.cat.default,
79+
exir_ops.edge.aten.clamp.default,
7980
exir_ops.edge.aten.bmm.default,
8081
exir_ops.edge.aten.permute_copy.default,
8182
exir_ops.edge.aten.hardtanh.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -12,6 +12,7 @@
1212
op_batch_norm,
1313
op_bmm,
1414
op_cat,
15+
op_clamp,
1516
op_conv2d,
1617
op_eq,
1718
op_exp,

backends/arm/operators/op_clamp.py

Lines changed: 144 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree
6+
7+
from typing import Any, List, Tuple
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
11+
import torch
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from executorch.backends.arm.tosa_specification import TosaSpecification
19+
from serializer.tosa_serializer import TosaOp
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
24+
class ClampVisitor_080_BI(NodeVisitor):
25+
target = "aten.clamp.default"
26+
27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
29+
]
30+
31+
def __init__(self, *args):
32+
super().__init__(*args)
33+
34+
def _create_clamp_node(
35+
self,
36+
tosa_graph: ts.TosaSerializer,
37+
input_name: str,
38+
output_name: str,
39+
min_int: int,
40+
max_int: int,
41+
min_fp32: float,
42+
max_fp32: float,
43+
) -> None:
44+
attr = ts.TosaSerializerAttribute()
45+
attr.ClampAttribute(
46+
tosa_graph.builder,
47+
min_int,
48+
max_int,
49+
min_fp32,
50+
max_fp32,
51+
)
52+
tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr)
53+
54+
def _get_min_max_arguments(
55+
self, node: Node, dtype_min: int | float, dtype_max: int | float
56+
) -> Tuple[int | float, int | float]:
57+
58+
def cast_type(value: Any) -> int | float:
59+
if isinstance(value, int):
60+
return value
61+
else:
62+
# Attempt to cast to float
63+
return float(value)
64+
65+
assert 2 <= len(node.args) <= 3
66+
67+
min_arg = dtype_min
68+
max_arg = dtype_max
69+
70+
if node.args[1] is not None:
71+
min_arg = cast_type(node.args[1])
72+
73+
if len(node.args) > 2:
74+
if node.args[2] is not None:
75+
max_arg = cast_type(node.args[2])
76+
77+
return min_arg, max_arg
78+
79+
def define_node(
80+
self,
81+
node: Node,
82+
tosa_graph: ts.TosaSerializer,
83+
inputs: List[TosaArg],
84+
output: TosaArg,
85+
) -> None:
86+
assert len(node.all_input_nodes) == 1
87+
88+
min_int8, max_int8 = self._get_min_max_arguments(
89+
node,
90+
torch.iinfo(torch.int8).min,
91+
torch.iinfo(torch.int8).max,
92+
)
93+
94+
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
95+
self._create_clamp_node(
96+
tosa_graph,
97+
inputs[0].name,
98+
output.name,
99+
int(min_int8),
100+
int(max_int8),
101+
0,
102+
0,
103+
)
104+
105+
106+
@register_node_visitor
107+
class ClampVisitor_080_MI(ClampVisitor_080_BI):
108+
# inheriting 'target' from BI class
109+
110+
tosa_specs = [
111+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
112+
]
113+
114+
def __init__(self, *args):
115+
super().__init__(*args)
116+
117+
def define_node(
118+
self,
119+
node: Node,
120+
tosa_graph: ts.TosaSerializer,
121+
inputs: List[TosaArg],
122+
output: TosaArg,
123+
) -> None:
124+
assert len(node.all_input_nodes) == 1
125+
126+
if inputs[0].dtype == ts.DType.INT8:
127+
# Call the inherited define_node for handling integers
128+
super().define_node(node, tosa_graph, inputs, output)
129+
else:
130+
min_fp32, max_fp32 = self._get_min_max_arguments(
131+
node,
132+
torch.finfo(torch.float32).min,
133+
torch.finfo(torch.float32).max,
134+
)
135+
136+
self._create_clamp_node(
137+
tosa_graph,
138+
inputs[0].name,
139+
output.name,
140+
0,
141+
0,
142+
min_fp32,
143+
max_fp32,
144+
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -186,6 +186,8 @@ def _match_pattern(
186186
torch.ops.aten.full.default,
187187
torch.ops.aten.flatten.using_ints,
188188
torch.ops.aten.dropout.default,
189+
torch.ops.aten.clamp.default,
190+
torch.ops.aten.clamp.Tensor,
189191
operator.getitem,
190192
]
191193

0 commit comments

Comments (0)