Skip to content

Commit df441f5

Browse files
committed
Add clamp operator to Arm backend
* Add support for aten.clamp * Amend QuantizeFullArgument pass to include quantization of clamp arguments Signed-off-by: Tom Allsop <[email protected]> Change-Id: I432f4ec60facc50fe45ca05c98308924d6e18109
1 parent debafbe commit df441f5

File tree

7 files changed

+346
-15
lines changed

7 files changed

+346
-15
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
3838
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
3939
FoldAndAnnotateQParamsPass,
40-
QuantizeFullArgument,
40+
QuantizeOperatorArguments,
4141
RetraceFoldedDtypesPass,
4242
)
4343
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
@@ -88,7 +88,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8888
self.add_pass(ConvertMeanDimToAveragePoolPass())
8989

9090
self.add_pass(AnnotateDecomposedMatmulPass())
91-
self.add_pass(QuantizeFullArgument())
91+
self.add_pass(QuantizeOperatorArguments())
9292
self.add_pass(FoldAndAnnotateQParamsPass())
9393
self.add_pass(RetraceFoldedDtypesPass())
9494
self.add_pass(InsertTableOpsPass(exported_program))
@@ -124,7 +124,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
124124
self.add_pass(DecomposeSoftmaxesPass())
125125

126126
self.add_pass(AnnotateDecomposedMatmulPass())
127-
self.add_pass(QuantizeFullArgument())
127+
self.add_pass(QuantizeOperatorArguments())
128128
self.add_pass(FoldAndAnnotateQParamsPass())
129129
self.add_pass(RetraceFoldedDtypesPass())
130130
self.add_pass(InsertTableOpsPass(exported_program))

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,19 +182,25 @@ def call(self, graph_module: GraphModule) -> PassResult:
182182
return PassResult(graph_module, True)
183183

184184

185-
class QuantizeFullArgument(ExportPass):
185+
class QuantizeOperatorArguments(ExportPass):
186186
"""
187-
Make sure the fill_value for full.default is quantized. This pass needs to be run before
188-
the folding pass above to make sure that the retraced output of the full.default op is
189-
the right dtype.
187+
This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
188+
More specifically, this pass:
189+
- Makes sure the fill_value for full.default is quantized. This pass needs to be run before
190+
the folding pass above to make sure that the retraced output of the full.default op is
191+
the right dtype.
192+
- Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
190193
"""
191194

192195
def call(self, graph_module: GraphModule) -> PassResult:
193196
modified = False
194197
# Loop over the graph nodes and find full.default nodes.
195198
for n in graph_module.graph.nodes:
196199
n = cast(Node, n)
197-
if n.target != exir_ops.edge.aten.full.default:
200+
if n.target not in {
201+
exir_ops.edge.aten.clamp.default,
202+
exir_ops.edge.aten.full.default,
203+
}:
198204
continue
199205

200206
# Make sure we have a quantized operator
@@ -203,13 +209,29 @@ def call(self, graph_module: GraphModule) -> PassResult:
203209
continue
204210

205211
qargs = QuantArgs.from_operator(user.target, user.args)
206-
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
207-
# replace the node arg with a quantized dito and also set dtype
208-
# to get the right output according to the Edge IR specification:
209-
# exir/dialects/edge/edge.yaml:3596
210-
quantized_full_value = qargs.quantize_value(n.args[1]).item()
211-
n.update_arg(1, quantized_full_value)
212-
n.update_kwarg("dtype", qargs.dtype)
212+
213+
if n.target == exir_ops.edge.aten.full.default:
214+
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
215+
# replace the node arg with a quantized dito and also set dtype
216+
# to get the right output according to the Edge IR specification:
217+
# exir/dialects/edge/edge.yaml:3596
218+
quantized_full_value = qargs.quantize_value(n.args[1]).item()
219+
n.update_arg(1, quantized_full_value)
220+
n.update_kwarg("dtype", qargs.dtype)
221+
modified = True
222+
elif n.target == exir_ops.edge.aten.clamp.default:
223+
# Quantize the min and max arguments of clamp, if they are not None
224+
min_val = n.args[1]
225+
max_val = None if len(n.args) <= 2 else n.args[2]
226+
227+
if min_val is not None:
228+
quantized_min_val = qargs.quantize_value(min_val).item()
229+
n.update_arg(1, quantized_min_val)
230+
231+
if max_val is not None:
232+
quantized_max_val = qargs.quantize_value(max_val).item()
233+
n.update_arg(2, quantized_max_val)
234+
213235
modified = True
214236

215237
return PassResult(graph_module, modified)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
7676
exir_ops.edge.aten.add.Tensor,
7777
exir_ops.edge.aten.expand_copy.default,
7878
exir_ops.edge.aten.cat.default,
79+
exir_ops.edge.aten.clamp.default,
7980
exir_ops.edge.aten.bmm.default,
8081
exir_ops.edge.aten.permute_copy.default,
8182
exir_ops.edge.aten.hardtanh.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
op_batch_norm,
1313
op_bmm,
1414
op_cat,
15+
op_clamp,
1516
op_conv2d,
1617
op_eq,
1718
op_exp,

backends/arm/operators/op_clamp.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree
6+
7+
from numbers import Number
8+
from typing import List, Tuple
9+
10+
import serializer.tosa_serializer as ts
11+
12+
import torch
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
18+
from executorch.backends.arm.tosa_mapping import TosaArg
19+
from executorch.backends.arm.tosa_specification import TosaSpecification
20+
from serializer.tosa_serializer import TosaOp
21+
from torch.fx import Node
22+
23+
24+
@register_node_visitor
25+
class ClampVisitor_080_BI(NodeVisitor):
26+
target = "aten.clamp.default"
27+
28+
tosa_specs = [
29+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
30+
]
31+
32+
def __init__(self, *args):
33+
super().__init__(*args)
34+
35+
def _create_clamp_node(
36+
self,
37+
tosa_graph: ts.TosaSerializer,
38+
input_name: str,
39+
output_name: str,
40+
min_int: int,
41+
max_int: int,
42+
min_fp32: float,
43+
max_fp32: float,
44+
):
45+
attr = ts.TosaSerializerAttribute()
46+
attr.ClampAttribute(
47+
tosa_graph.builder,
48+
min_int,
49+
max_int,
50+
min_fp32,
51+
max_fp32,
52+
)
53+
tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr)
54+
55+
def _get_min_max_arguments(
56+
self,
57+
node: Node,
58+
dtype_min: Number,
59+
dtype_max: Number,
60+
) -> Tuple[Number, Number]:
61+
assert 2 <= len(node.args) <= 3
62+
63+
min_arg = dtype_min
64+
max_arg = dtype_max
65+
66+
if node.args[1] is not None:
67+
min_arg = node.args[1]
68+
69+
if len(node.args) > 2:
70+
if node.args[2] is not None:
71+
max_arg = node.args[2]
72+
73+
return min_arg, max_arg
74+
75+
def define_node(
76+
self,
77+
node: Node,
78+
tosa_graph: ts.TosaSerializer,
79+
inputs: List[TosaArg],
80+
output: TosaArg,
81+
) -> None:
82+
assert len(node.all_input_nodes) == 1
83+
84+
min_int8, max_int8 = self._get_min_max_arguments(
85+
node,
86+
torch.iinfo(torch.int8).min,
87+
torch.iinfo(torch.int8).max,
88+
)
89+
90+
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
91+
self._create_clamp_node(
92+
tosa_graph,
93+
inputs[0].name,
94+
output.name,
95+
min_int8,
96+
max_int8,
97+
0,
98+
0,
99+
)
100+
101+
102+
@register_node_visitor
103+
class ClampVisitor_080_MI(ClampVisitor_080_BI):
104+
# inheriting 'target' from BI class
105+
106+
tosa_specs = [
107+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
108+
]
109+
110+
def __init__(self, *args):
111+
super().__init__(*args)
112+
113+
def define_node(
114+
self,
115+
node: Node,
116+
tosa_graph: ts.TosaSerializer,
117+
inputs: List[TosaArg],
118+
output: TosaArg,
119+
) -> None:
120+
assert len(node.all_input_nodes) == 1
121+
122+
if inputs[0].dtype == ts.DType.INT8:
123+
# Call the inherited define_node for handling integers
124+
super().define_node(node, tosa_graph, inputs, output)
125+
else:
126+
min_fp32, max_fp32 = self._get_min_max_arguments(
127+
node,
128+
torch.finfo(torch.float32).min,
129+
torch.finfo(torch.float32).max,
130+
)
131+
132+
self._create_clamp_node(
133+
tosa_graph,
134+
inputs[0].name,
135+
output.name,
136+
0,
137+
0,
138+
min_fp32,
139+
max_fp32,
140+
)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ def _match_pattern(
186186
torch.ops.aten.full.default,
187187
torch.ops.aten.flatten.using_ints,
188188
torch.ops.aten.dropout.default,
189+
torch.ops.aten.clamp.default,
190+
torch.ops.aten.clamp.Tensor,
189191
operator.getitem,
190192
]
191193

0 commit comments

Comments
 (0)