Skip to content

Commit 2c603e4

Browse files
Arm backend: Move rescale ops out of node visitors (#14584)
Some TOSA ops do not support INT8 as inputs and outputs. Instead, only INT32 is supported as a whole number type. Prior to this patch, affected node visitors inserted rescale ops between the data types INT8 and INT32 before and after the operator so that it would accept its input and output. Change this by moving the insertion of the rescale ops to a new pass called `InsertRescaleInt32Pass`. This will further enable optimizations to the graph by fusing the rescale nodes. Only comparison, ABS, MAXIMUM and MINIMUM operators are handled in this patch; the remaining ones are left out to be done in another patch. ### Test plan This is a refactoring, which means that external behavior is not altered. A new pass `InsertRescaleInt32Pass` has been added and it comes with a new unit test in backends/arm/test/passes/test_insert_rescale_i32_pass.py. Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Oscar Andersson <[email protected]>
1 parent 0b748bf commit 2c603e4

File tree

12 files changed

+341
-238
lines changed

12 files changed

+341
-238
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
from .insert_int32_casts_after_int64_placeholders import ( # noqa
8282
InsertInt32CastsAfterInt64PlaceholdersPass,
8383
)
84-
from .insert_rescales_pass import InsertRescalePass # noqa
84+
from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa
8585
from .insert_table_ops import InsertTableOpsPass # noqa
8686
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
8787
from .match_arg_ranks_pass import MatchArgRanksPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
FuseEqualPlaceholdersPass,
8282
FuseQuantizedActivationPass,
8383
InsertInt32CastsAfterInt64PlaceholdersPass,
84+
InsertRescaleInt32Pass,
8485
InsertRescalePass,
8586
InsertTableOpsPass,
8687
MatchArgDtypePass,
@@ -214,6 +215,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
214215
self.add_pass(ToTosaMemoryFormatPass(exported_program))
215216
self.add_pass(RemoveNoopPass())
216217
self.add_pass(InsertRescalePass())
218+
self.add_pass(InsertRescaleInt32Pass())
217219

218220
self.validate_constraints_mandatory()
219221
return self._transform(exported_program.graph_module)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 238 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from copy import copy
7-
from typing import cast, Set, Type
7+
from typing import cast, Dict, Optional, Set, Tuple, Type
88

9-
from executorch.backends.arm._passes.arm_pass_utils import create_node
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
12+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
13+
get_output_qparams,
14+
)
1015
from executorch.backends.arm._passes.quant_args import QuantArgs
1116
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1217
from executorch.exir.dialects._ops import ops as exir_ops
@@ -65,3 +70,234 @@ def call(self, graph_module: GraphModule) -> PassResult:
6570
graph_module = super().call(graph_module).graph_module
6671
graph_module.recompile()
6772
return PassResult(graph_module, modified)
73+
74+
75+
class InsertRescaleInt32Pass(ArmPass):
    """Insert INT8 <-> INT32 rescales around ops that only support INT32.

    Numerous TOSA ops require inputs and outputs to be 32-bit integers in
    their quantized implementations. This pass treats such operator nodes by
    inserting rescale ops before and after them if needed. Note that extra
    logic that handles the scales and zero points must be in place because
    the affected TOSA ops have naive implementations that do not account for
    the quantization parameters.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    # Targets whose quantized TOSA lowerings require INT32 operands.
    included_targets = [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.eq.Tensor,
        exir_ops.edge.aten.ge.Tensor,
        exir_ops.edge.aten.gt.Tensor,
        exir_ops.edge.aten.le.Tensor,
        exir_ops.edge.aten.lt.Tensor,
        exir_ops.edge.aten.maximum.default,
        exir_ops.edge.aten.minimum.default,
    ]

    def _int32_qargs(self, s):
        """Helper creator function for INT32-based ``QuantArgs`` with scale ``s``."""
        return QuantArgs(
            scale=s,
            zp=0,
            qmin=torch.iinfo(torch.int32).min,
            qmax=torch.iinfo(torch.int32).max,
            dtype=torch.int32,
        )

    def _get_inputs_rescaled_qparams(
        self, target, input_qparams: Dict[int, QuantArgs]
    ) -> Dict[int, QuantArgs]:
        """Get the qparams for the INT32 operands to the op ``target``.

        Inputs to the INT32-based operator must be rescaled from INT8 to INT32.
        This function computes the ``QuantArgs`` for each of the operands and
        returns it as a dict, mapping tensor index to ``QuantArgs``.

        Raises:
            ValueError: If ``target`` is not one of the handled operators.
        """
        if target in [
            exir_ops.edge.aten.abs.default,
            exir_ops.edge.aten.eq.Tensor,
            exir_ops.edge.aten.ge.Tensor,
            exir_ops.edge.aten.gt.Tensor,
            exir_ops.edge.aten.le.Tensor,
            exir_ops.edge.aten.lt.Tensor,
            exir_ops.edge.aten.minimum.default,
            exir_ops.edge.aten.maximum.default,
        ]:
            # For these ops, use the smallest scale among the INT8 operands.
            min_scale = min(
                qp.get_scale_per_tensor() for qp in input_qparams.values()
            )
            qparams = {
                i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
            }
        else:
            raise ValueError(f"Not a valid target: {target}")

        return qparams

    def _get_output_qparams(
        self, target, inputs_qparams: Dict[int, QuantArgs]
    ) -> Optional[QuantArgs]:
        """Given an op ``target`` and the ``QuantArgs`` for each of its inputs,
        compute the scale of the output based on how the operator itself
        affects it. Returns ``None`` for ops with a non-quantized (bool) output.

        Raises:
            ValueError: If ``target`` is not one of the handled operators.
        """
        if target in [
            exir_ops.edge.aten.abs.default,
            exir_ops.edge.aten.maximum.default,
            exir_ops.edge.aten.minimum.default,
        ]:
            # The op has not altered the scale; the output scale is equal to
            # the operands' scales.
            return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor())
        elif target in [
            exir_ops.edge.aten.eq.Tensor,
            exir_ops.edge.aten.ge.Tensor,
            exir_ops.edge.aten.gt.Tensor,
            exir_ops.edge.aten.le.Tensor,
            exir_ops.edge.aten.lt.Tensor,
        ]:
            # Output is bool for these ops and thus no qparams are present
            return None
        else:
            raise ValueError(f"Not a valid target: {target}")

    def _get_rescale_qparams(
        self, target, input_qparams: Dict[int, QuantArgs]
    ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
        """
        Get the quantization parameters of the INT32 inputs/outputs that will
        surround the node after the new RESCALE ops have been inserted.
        """
        inputs_rescaled_qparams = self._get_inputs_rescaled_qparams(
            target, input_qparams
        )
        output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams)

        return (inputs_rescaled_qparams, output_qparams)

    def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
        """Insert INT8 -> INT32 RESCALE nodes in front of ``node``'s INT8 inputs.

        Returns True if at least one rescale was inserted.
        """
        qargs = node.meta["input_qparams"]

        args_copy = list(node.args)
        seen_args = set()
        modified = False
        for i in qargs:
            qp = qargs[i]
            if qp.dtype != torch.int8:
                # Only INT8 operands need widening to INT32.
                continue

            arg_node = args_copy[i]
            if arg_node in seen_args:
                # replace_input_with() swaps every use at once, so a repeated
                # operand only needs a single rescale node.
                continue
            seen_args.add(arg_node)

            with graph.inserting_after(arg_node):
                rescale_node = create_node(
                    graph,
                    exir_ops.backend.tosa.RESCALE.default,
                    (
                        arg_node,
                        torch.int32,
                        qp.get_scale_per_tensor()
                        / rescale_qargs[
                            i
                        ].get_scale_per_tensor(),  # Old scale / new scale
                        qp.get_zp_per_tensor(),  # Old zero point
                        rescale_qargs[i].get_zp_per_tensor(),  # New zero point
                    ),
                    from_node=node,
                )

            node.replace_input_with(arg_node, rescale_node)
            modified = True

        return modified

    def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool:
        """Insert an INT32 -> INT8 RESCALE node after ``node`` if its output is INT8.

        Returns True if a rescale was inserted.
        """
        if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0:
            # No quantized output (e.g. bool output of a comparison op).
            return False

        qargs = get_output_qparams(node)
        assert len(qargs) == 1
        assert rescale_qargs is not None

        qarg = qargs[0]
        if qarg.dtype != torch.int8:
            return False

        # Snapshot users before inserting the rescale, since the insertion
        # itself adds a new user of ``node``.
        users_copy = list(node.users)

        with graph.inserting_after(node):
            rescale_node = create_node(
                graph,
                exir_ops.backend.tosa.RESCALE.default,
                (
                    node,
                    torch.int8,
                    rescale_qargs.get_scale_per_tensor()
                    / qarg.get_scale_per_tensor(),  # Old scale / new scale
                    rescale_qargs.get_zp_per_tensor(),  # Old zero point
                    qarg.get_zp_per_tensor(),  # New zero point
                ),
                from_node=node,
            )

        for user in users_copy:
            user.replace_input_with(node, rescale_node)

        return True

    def call(self, graph_module: GraphModule) -> PassResult:
        """Run the pass over ``graph_module``, rescaling inputs/outputs of the
        included targets and updating their quantization metadata."""
        graph = graph_module.graph

        modified = False
        for node in list(graph.nodes):
            node = cast(Node, node)

            if node.op != "call_function" or node.target not in self.included_targets:
                continue

            if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0:
                continue
            input_qparams = node.meta["input_qparams"]

            inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams(
                node.target, input_qparams
            )

            inputs_was_rescaled = self._rescale_inputs(
                graph, node, inputs_rescale_qargs
            )
            outputs_was_rescaled = False
            if inputs_was_rescaled:
                outputs_was_rescaled = self._rescale_outputs(
                    graph, node, output_rescale_qargs
                )
                modified = True

            # Update node metadata

            if inputs_was_rescaled:
                assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"])
                node.meta["input_qparams"] = inputs_rescale_qargs

            if outputs_was_rescaled:
                assert len(node.meta["output_qparams"]) == 1
                node.meta["output_qparams"] = {0: output_rescale_qargs}

                # If the output type is specified in the node, change it such
                # that it matches the subsequent rescale node(s) that this node
                # now has output edges to.
                if "dtype" in node.kwargs:
                    set_node_arg(node, "dtype", torch.int32)

        if modified:
            # Retrace the graph to update the fake tensor types
            graph_module = super().call(graph_module).graph_module
            graph_module.recompile()

        return PassResult(graph_module, modified)

0 commit comments

Comments
 (0)