Skip to content

Commit 82a2324

Browse files
Arm backend: Move rescales from MUL visitor to pass (#15103)
Some TOSA ops do not support INT8 inputs and outputs; the only whole-number type they support is INT32. Prior to this patch, the MUL node visitor inserted rescale ops between the INT8 and INT32 data types before and after the operator so that the operator would accept its inputs and outputs. This patch moves the insertion of those rescale ops to the pass `InsertRescaleInt32Pass`. This will enable further optimizations to the graph by fusing the rescale nodes.

### Test plan

Test coverage is found in the modified test_insert_rescale_i32_pass.py.

Signed-off-by: Martin Lindström <[email protected]>
Co-authored-by: Oscar Andersson <[email protected]>
1 parent 519c7ff commit 82a2324

File tree

3 files changed

+41
-113
lines changed

3 files changed

+41
-113
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import math
67
from copy import copy
78
from typing import cast, Dict, Optional, Set, Tuple, Type
89

@@ -94,6 +95,7 @@ class InsertRescaleInt32Pass(ArmPass):
9495
exir_ops.edge.aten.lt.Tensor,
9596
exir_ops.edge.aten.maximum.default,
9697
exir_ops.edge.aten.minimum.default,
98+
exir_ops.edge.aten.mul.Tensor,
9799
]
98100

99101
def _int32_qargs(self, s):
@@ -134,6 +136,15 @@ def _get_inputs_rescaled_qparams(
134136
qparams = {
135137
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
136138
}
139+
elif target in [
140+
exir_ops.edge.aten.mul.Tensor,
141+
]:
142+
# The input scales do not need to be adjusted for these ops; they
143+
# can remain the same.
144+
qparams = {
145+
i: self._int32_qargs(input_qparams[i].get_scale_per_tensor())
146+
for i in range(len(input_qparams))
147+
}
137148
else:
138149
raise ValueError(f"Not a valid target: {target}")
139150

@@ -162,6 +173,20 @@ def _get_output_qparams(
162173
]:
163174
# Output is bool for these ops and thus no qparams are present
164175
return None
176+
elif target in [exir_ops.edge.aten.mul.Tensor]:
177+
# Mul will cause the scales to also multiply; refer to the formula
178+
# where we compute the output scale S_2:
179+
#
180+
# (Q_2 - ZP_2) * S_2 == ((Q_0 - ZP_0) * S_0) * ((Q_1 - ZP_1) * S_1)
181+
#
182+
# yields:
183+
#
184+
# (Q_2 - ZP_2) == (Q_0 - ZP_0) * (Q_1 - ZP_1)
185+
# S_2 = S_0 * S_1
186+
output_scale = math.prod(
187+
(qp.get_scale_per_tensor() for qp in inputs_qparams.values())
188+
)
189+
return self._int32_qargs(output_scale)
165190
else:
166191
raise ValueError(f"Not a valid target: {target}")
167192

@@ -188,7 +213,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
188213
modified = False
189214
for i in qargs:
190215
qp = qargs[i]
191-
if qp.dtype != torch.int8:
216+
if qp.dtype not in (torch.int8, torch.int16):
192217
continue
193218

194219
arg_node = args_copy[i]
@@ -227,7 +252,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
227252
assert rescale_qargs is not None
228253

229254
qarg = qargs[0]
230-
if qarg.dtype != torch.int8:
255+
if qarg.dtype not in (torch.int8, torch.int16):
231256
return False
232257

233258
users_copy = list(node.users)
@@ -238,7 +263,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
238263
exir_ops.backend.tosa.RESCALE.default,
239264
(
240265
node,
241-
torch.int8,
266+
qarg.dtype,
242267
rescale_qargs.get_scale_per_tensor()
243268
/ qarg.get_scale_per_tensor(), # Old scale / new scale
244269
rescale_qargs.get_zp_per_tensor(), # Old zero point

backends/arm/operators/op_mul.py

Lines changed: 6 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,8 @@
77

88
from typing import Any, List
99

10-
import executorch.backends.arm.tosa.quant_utils as tqutils
11-
import executorch.backends.arm.tosa.utils as tutils
1210
import torch
1311

14-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15-
get_input_qparams,
16-
)
17-
1812
from executorch.backends.arm.operators.node_visitor import (
1913
NodeVisitor,
2014
register_node_visitor,
@@ -24,17 +18,17 @@
2418
validate_same_dtype,
2519
validate_valid_dtype,
2620
)
27-
from executorch.backends.arm.tosa import TosaSpecification
2821
from executorch.backends.arm.tosa.mapping import TosaArg
22+
from executorch.backends.arm.tosa.specification import TosaSpecification
2923

3024

3125
@register_node_visitor
32-
class MulVisitor_INT(NodeVisitor):
26+
class MulVisitor(NodeVisitor):
3327
target = "aten.mul.Tensor"
3428

3529
tosa_specs = [
30+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3631
TosaSpecification.create_from_string("TOSA-1.0+INT"),
37-
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
3832
]
3933

4034
def define_node(
@@ -52,105 +46,13 @@ def define_node(
5246
validate_valid_dtype(
5347
self.target,
5448
[*inputs, output],
55-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
49+
[ts.DType.INT32, ts.DType.FP32],
5650
output.tosa_spec,
5751
)
5852

59-
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
60-
input_A = inputs[0]
61-
input_B = inputs[1]
62-
input_qparams = get_input_qparams(node)
63-
input_A_qargs = input_qparams[0]
64-
input_B_qargs = input_qparams[1]
65-
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
66-
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
67-
68-
# Rescale inputs to INT32 with zp=0
69-
input_A_rescaled = tqutils.build_rescale_to_int32(
70-
tosa_graph,
71-
input_A,
72-
input_A_qargs.get_zp_per_tensor(),
73-
1.0,
74-
tosa_spec=self.tosa_spec,
75-
)
76-
input_B_rescaled = tqutils.build_rescale_to_int32(
77-
tosa_graph,
78-
input_B,
79-
input_B_qargs.get_zp_per_tensor(),
80-
1.0,
81-
tosa_spec=self.tosa_spec,
82-
)
83-
else:
84-
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
85-
# Non quantized input, natively support by TOSA.MUL
86-
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
87-
88-
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
89-
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
90-
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
91-
else:
92-
# output.dtype == ts.DType.INT32 (non-quantized)
93-
mul_output = output
94-
95-
# Do the INT32 Mul
96-
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
97-
self._serialize_operator(
98-
node,
99-
tosa_graph,
100-
ts.TosaOp.Op().MUL,
101-
[input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"],
102-
[mul_output.name],
103-
)
104-
105-
if output.dtype == ts.DType.INT8:
106-
# Scale output back to 8 bit
107-
output_scale = (
108-
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
109-
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
110-
)
111-
tqutils.insert_rescale_op_to_int8(
112-
tosa_graph, mul_output, output_scale, node, self.tosa_spec
113-
)
114-
elif output.dtype == ts.DType.INT16:
115-
# Scale output back to 16 bit
116-
output_scale = (
117-
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
118-
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
119-
)
120-
tqutils.insert_rescale_op_to_int16(
121-
tosa_graph, mul_output, output_scale, node, self.tosa_spec
122-
)
123-
124-
125-
@register_node_visitor
126-
class MulVisitor_FP(MulVisitor_INT):
127-
# inheriting 'target' from INT class
128-
129-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
130-
131-
def define_node(
132-
self,
133-
node: torch.fx.Node,
134-
tosa_graph: Any,
135-
inputs: List[TosaArg],
136-
output: TosaArg,
137-
) -> None:
138-
139-
import serializer.tosa_serializer as ts # type: ignore
140-
141-
validate_num_inputs(self.target, inputs, 2)
142-
validate_same_dtype(self.target, [*inputs, output], ts)
143-
144-
if inputs[0].dtype == ts.DType.INT8:
145-
return super().define_node(node, tosa_graph, inputs, output)
146-
147-
input1, input2 = inputs
148-
14953
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
150-
self._serialize_operator(
151-
node,
152-
tosa_graph,
54+
tosa_graph.addOperator(
15355
ts.TosaOp.Op().MUL,
154-
[input1.name, input2.name, f"{node.name}_shift"],
56+
[inputs[0].name, inputs[1].name, f"{node.name}_shift"],
15557
[output.name],
15658
)

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ def __init__(self):
2222
super().__init__()
2323

2424
def forward(self, x, y):
25-
a = torch.maximum(x, y)
26-
b = torch.abs(a)
27-
c = a > b
28-
return c
25+
a = x * y
26+
b = torch.maximum(a, y)
27+
c = torch.abs(b)
28+
d = c > b
29+
return d
2930

3031
def get_inputs(self, dtype) -> input_t:
3132
if dtype == torch.float32:
@@ -45,8 +46,8 @@ def test_insert_rescales():
4546
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
4647
ops_after = {
4748
# "number of op nodes with i8 output" + "number of i8 node inputs"
48-
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2
49-
+ 5,
49+
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
50+
+ 7,
5051
}
5152
pipeline = PassPipeline[input_t](
5253
module,

0 commit comments

Comments
 (0)