Commit 2cda2ff

Martin Lindström authored and committed
Arm backend: Move rescale ops out of comparison visitors
Some TOSA ops do not support INT8 inputs and outputs; only INT32 is supported as a whole-number type. Prior to this patch, the affected node visitors inserted rescale ops between INT8 and INT32 before and after the operator so that it would accept its inputs and produce its output. Change this by moving the insertion of the rescale ops to a new pass called InsertRescaleInt32Pass. This also enables further graph optimizations, such as fusing the rescale nodes. Only comparison operators are handled in this patch; the remaining ones are left for a follow-up patch.

Signed-off-by: Martin Lindström <[email protected]>
Change-Id: I6bb8a10a0b453ae9fd8b8604d64cc5103a4da050
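Not part of the commit itself, just intuition: the sketch below shows in plain Python the arithmetic a TOSA RESCALE performs, and why comparing two operands in the INT32 domain at a shared scale (the minimum of the operand scales, as the new pass chooses) gives the same result as comparing the dequantized real values. All names and values are illustrative.

def rescale(q, old_scale, old_zp, new_scale, new_zp=0):
    # Re-express the real value (q - old_zp) * old_scale at new_scale / new_zp,
    # which is what a RESCALE op does per element (rounding aside).
    return round((q - old_zp) * (old_scale / new_scale)) + new_zp

# Two INT8 operands quantized with different scales and zero points.
a_q, a_scale, a_zp = 57, 0.02, -5    # real value: (57 - (-5)) * 0.02 = 1.24
b_q, b_scale, b_zp = 100, 0.01, 10   # real value: (100 - 10) * 0.01 = 0.90

# The pass rescales both operands to INT32 at the smallest operand scale.
common_scale = min(a_scale, b_scale)
a_i32 = rescale(a_q, a_scale, a_zp, common_scale)  # 124
b_i32 = rescale(b_q, b_scale, b_zp, common_scale)  # 90

# Comparing in the INT32 domain matches comparing the real values.
assert (a_i32 > b_i32) == ((a_q - a_zp) * a_scale > (b_q - b_zp) * b_scale)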
1 parent db8d04f commit 2cda2ff

File tree: 9 files changed (+272, -73 lines)

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@
 from .insert_int32_casts_after_int64_placeholders import (  # noqa
     InsertInt32CastsAfterInt64PlaceholdersPass,
 )
-from .insert_rescales_pass import InsertRescalePass  # noqa
+from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass  # noqa
 from .insert_table_ops import InsertTableOpsPass  # noqa
 from .match_arg_dtype_pass import MatchArgDtypePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,7 @@
     FuseEqualPlaceholdersPass,
     FuseQuantizedActivationPass,
     InsertInt32CastsAfterInt64PlaceholdersPass,
+    InsertRescaleInt32Pass,
     InsertRescalePass,
     InsertTableOpsPass,
     MatchArgDtypePass,
@@ -214,6 +215,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
         self.add_pass(InsertRescalePass())
+        self.add_pass(InsertRescaleInt32Pass())

         self.validate_constraints_mandatory()
         return self._transform(exported_program.graph_module)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 189 additions & 2 deletions
@@ -4,9 +4,14 @@
 # LICENSE file in the root directory of this source tree.

 from copy import copy
-from typing import cast, Set, Type
+from typing import cast, Dict, Optional, Set, Tuple, Type

-from executorch.backends.arm._passes.arm_pass_utils import create_node
+import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_output_qparams,
+)
 from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -65,3 +70,185 @@ def call(self, graph_module: GraphModule) -> PassResult:
         graph_module = super().call(graph_module).graph_module
         graph_module.recompile()
         return PassResult(graph_module, modified)
+
+
+class InsertRescaleInt32Pass(ArmPass):
+    """
+    Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
+    quantized implementations. This pass treats such operator nodes by
+    inserting rescale ops before and after them if needed. Note that extra logic
+    that handles the scales and zero points must be in place because the affected
+    TOSA ops have naive implementations that do not account for the quantization
+    parameters.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    included_targets = [
+        exir_ops.edge.aten.eq.Tensor,
+        exir_ops.edge.aten.ge.Tensor,
+        exir_ops.edge.aten.gt.Tensor,
+        exir_ops.edge.aten.le.Tensor,
+        exir_ops.edge.aten.lt.Tensor,
+    ]
+
+    def _get_rescale_qparams(
+        self, target, input_qparams: Dict[int, QuantArgs]
+    ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
+        """
+        Get the quantization parameters of the Int32 inputs/outputs that will
+        surround the node.
+        """
+
+        # Helper creator function for Int32-based QuantArgs
+        def int32_qargs(s):
+            return QuantArgs(
+                scale=s,
+                zp=0,
+                qmin=torch.iinfo(torch.int32).min,
+                qmax=torch.iinfo(torch.int32).max,
+                dtype=torch.int32,
+            )
+
+        if target in [
+            exir_ops.edge.aten.eq.Tensor,
+            exir_ops.edge.aten.ge.Tensor,
+            exir_ops.edge.aten.gt.Tensor,
+            exir_ops.edge.aten.le.Tensor,
+            exir_ops.edge.aten.lt.Tensor,
+        ]:
+            # Use the lowest scale of the operands since that yields the best numerical precision.
+            min_scale = min(
+                [qp.get_scale_per_tensor() for qp in input_qparams.values()]
+            )
+            inputs_rescale_qparams = {
+                i: int32_qargs(min_scale) for i in range(len(input_qparams))
+            }
+
+            # Return None as output quant args since the output is not quantized (bool dtype)
+            return (inputs_rescale_qparams, None)
+        else:
+            raise ValueError(f"Unknown target: {target}")
+
+    def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
+        qargs = node.meta["input_qparams"]
+
+        args_copy = list(node.args)
+        seen_args = set()
+        modified = False
+        for i in qargs:
+            qp = qargs[i]
+            if qp.dtype != torch.int8:
+                continue
+
+            arg_node = args_copy[i]
+            if arg_node in seen_args:
+                continue
+            seen_args.add(arg_node)
+
+            with graph.inserting_after(arg_node):
+                rescale_node = create_node(
+                    graph,
+                    exir_ops.backend.tosa.RESCALE.default,
+                    (
+                        arg_node,
+                        torch.int32,
+                        qp.get_scale_per_tensor()
+                        / rescale_qargs[
+                            i
+                        ].get_scale_per_tensor(),  # Old scale / new scale
+                        qp.get_zp_per_tensor(),  # Old zero point
+                        rescale_qargs[i].get_zp_per_tensor(),  # New zero point
+                    ),
+                    from_node=node,
+                )

+            node.replace_input_with(arg_node, rescale_node)
+            modified = True
+
+        return modified
+
+    def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool:
+        if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0:
+            return False
+
+        qargs = get_output_qparams(node)
+        assert len(qargs) == 1
+        assert rescale_qargs is not None
+
+        qarg = qargs[0]
+        if qarg.dtype != torch.int8:
+            return False
+
+        users_copy = list(node.users)
+
+        with graph.inserting_after(node):
+            rescale_node = create_node(
+                graph,
+                exir_ops.backend.tosa.RESCALE.default,
+                (
+                    node,
+                    torch.int8,
+                    rescale_qargs.get_scale_per_tensor()
+                    / qarg.get_scale_per_tensor(),  # Old scale / new scale
+                    rescale_qargs.get_zp_per_tensor(),  # Old zero point
+                    qarg.get_zp_per_tensor(),  # New zero point
+                ),
+                from_node=node,
+            )
+
+        for user in users_copy:
+            user.replace_input_with(node, rescale_node)
+
+        return True
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        graph = graph_module.graph
+
+        modified = False
+        for node in list(graph.nodes):
+            node = cast(Node, node)
+
+            if node.op != "call_function" or node.target not in self.included_targets:
+                continue
+
+            if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0:
+                continue
+            input_qparams = node.meta["input_qparams"]
+
+            inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams(
+                node.target, input_qparams
+            )
+
+            inputs_was_rescaled = self._rescale_inputs(
+                graph, node, inputs_rescale_qargs
+            )
+            outputs_was_rescaled = False
+            if inputs_was_rescaled:
+                outputs_was_rescaled = self._rescale_outputs(
+                    graph, node, output_rescale_qargs
+                )
+                modified = True
+
+            # Update node metadata
+
+            if inputs_was_rescaled:
+                assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"])
+                node.meta["input_qparams"] = inputs_rescale_qargs
+
+            if outputs_was_rescaled:
+                assert len(node.meta["output_qparams"]) == 1
+                node.meta["output_qparams"] = {0: output_rescale_qargs}
+
+                # If the output type is specified in the node, change it such
+                # that it matches the subsequent rescale node(s) that this node
+                # now has output edges to.
+                if "dtype" in node.kwargs:
+                    set_node_arg(node, "dtype", torch.int32)
+
+        if modified:
+            # Retrace the graph to update the fake tensor types
+            graph_module = super().call(graph_module).graph_module
+            graph_module.recompile()
+
+        return PassResult(graph_module, modified)
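Not part of the diff: a minimal standalone usage sketch of the new pass, assuming a hypothetical edge-dialect GraphModule gm whose comparison nodes already carry the "input_qparams"/"output_qparams" metadata. In practice the pass runs inside ArmPassManager's _tosa_INT_pipeline, as registered above.

from executorch.backends.arm._passes import InsertRescaleInt32Pass

# `gm` is a placeholder for an edge-dialect torch.fx.GraphModule with quantization
# metadata on its comparison nodes; obtaining it is not shown here.
result = InsertRescaleInt32Pass().call(gm)
if result.modified:
    gm = result.graph_module  # comparisons now consume INT32 RESCALE outputs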

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 14 deletions
@@ -7,8 +7,6 @@

 from typing import Any, List

-import executorch.backends.arm.tosa.quant_utils as tqutils
-
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -56,23 +54,12 @@ def define_node(
         )
         validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

-        input_nodes = inputs
-        # Handle quantization
-        if inputs[0].dtype == ts.DType.INT8:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node, self.tosa_spec
-            )
-
-            # Update IO
-            input_nodes = rescaled_inputs
-
         # Do the equal comparison
         self._serialize_operator(
             node,
             tosa_graph,
             ts.TosaOp.Op().EQUAL,
-            [input_nodes[0].name, input_nodes[1].name],
+            [inputs[0].name, inputs[1].name],
             [output.name],
             None,
         )

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 14 deletions
@@ -7,8 +7,6 @@

 from typing import Any, List

-import executorch.backends.arm.tosa.quant_utils as tqutils
-
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -56,22 +54,11 @@ def define_node(
         )
         validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

-        input_nodes = inputs
-        # Handle quantization
-        if inputs[0].dtype == ts.DType.INT8:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node, self.tosa_spec
-            )
-
-            # Update IO
-            input_nodes = rescaled_inputs
-
         self._serialize_operator(
             node,
             tosa_graph,
             ts.TosaOp.Op().GREATER_EQUAL,
-            [input_nodes[0].name, input_nodes[1].name],
+            [inputs[0].name, inputs[1].name],
             [output.name],
             None,
         )

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 14 deletions
@@ -7,8 +7,6 @@

 from typing import Any, List

-import executorch.backends.arm.tosa.quant_utils as tqutils
-
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -56,22 +54,11 @@ def define_node(
         )
         validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

-        input_nodes = inputs
-        # Handle quantization
-        if inputs[0].dtype == ts.DType.INT8:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node, self.tosa_spec
-            )
-
-            # Update IO
-            input_nodes = rescaled_inputs
-
         self._serialize_operator(
             node,
             tosa_graph,
             ts.TosaOp.Op().GREATER,
-            [input_nodes[0].name, input_nodes[1].name],
+            [inputs[0].name, inputs[1].name],
             [output.name],
             None,
         )

backends/arm/operators/op_le.py

Lines changed: 1 addition & 14 deletions
@@ -7,8 +7,6 @@

 from typing import Any, List

-import executorch.backends.arm.tosa.quant_utils as tqutils
-
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -56,22 +54,11 @@ def define_node(
        )
         validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

-        input_nodes = inputs
-        # Handle quantization
-        if inputs[0].dtype == ts.DType.INT8:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node, self.tosa_spec
-            )
-
-            # Update IO
-            input_nodes = rescaled_inputs
-
         self._serialize_operator(
             node,
             tosa_graph,
             ts.TosaOp.Op().GREATER_EQUAL,
-            [input_nodes[1].name, input_nodes[0].name],
+            [inputs[1].name, inputs[0].name],
             [output.name],
             None,
         )

backends/arm/operators/op_lt.py

Lines changed: 1 addition & 14 deletions
@@ -7,8 +7,6 @@

 from typing import Any, List

-import executorch.backends.arm.tosa.quant_utils as tqutils
-
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -56,22 +54,11 @@ def define_node(
         )
         validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

-        input_nodes = inputs
-        # Handle quantization
-        if inputs[0].dtype == ts.DType.INT8:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node, self.tosa_spec
-            )
-
-            # Update IO
-            input_nodes = rescaled_inputs
-
         self._serialize_operator(
             node,
             tosa_graph,
             ts.TosaOp.Op().GREATER,
-            [input_nodes[1].name, input_nodes[0].name],
+            [inputs[1].name, inputs[0].name],
             [output.name],
             None,
         )
