Skip to content

Commit 21876ea

Browse files
Martin Lindströmoscarandersson8218
authored andcommitted
Arm backend: Move rescales from MUL visitor to pass
Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Oscar Andersson <[email protected]> Change-Id: Ie6e019d21ae1868512eb81e4a0f97c3484cd3962
1 parent 57a7903 commit 21876ea

File tree

3 files changed

+41
-113
lines changed

3 files changed

+41
-113
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import math
67
from copy import copy
78
from typing import cast, Dict, Optional, Set, Tuple, Type
89

@@ -93,6 +94,7 @@ class InsertRescaleInt32Pass(ArmPass):
9394
exir_ops.edge.aten.lt.Tensor,
9495
exir_ops.edge.aten.maximum.default,
9596
exir_ops.edge.aten.minimum.default,
97+
exir_ops.edge.aten.mul.Tensor,
9698
]
9799

98100
def _int32_qargs(self, s):
@@ -133,6 +135,15 @@ def _get_inputs_rescaled_qparams(
133135
qparams = {
134136
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
135137
}
138+
elif target in [
139+
exir_ops.edge.aten.mul.Tensor,
140+
]:
141+
# The input scales do not need to be adjusted for these ops; they
142+
# can remain the same.
143+
qparams = {
144+
i: self._int32_qargs(input_qparams[i].get_scale_per_tensor())
145+
for i in range(len(input_qparams))
146+
}
136147
else:
137148
raise ValueError(f"Not a valid target: {target}")
138149

@@ -161,6 +172,20 @@ def _get_output_qparams(
161172
]:
162173
# Output is bool for these ops and thus no qparams are present
163174
return None
175+
elif target in [exir_ops.edge.aten.mul.Tensor]:
176+
# Mul will cause the scales to also multiply; refer to the formula
177+
# where we compute the output scale S_2:
178+
#
179+
# (Q_2 - ZP_2) * S_2 == ((Q_0 - ZP_0) * S_0) * ((Q_1 - ZP_1) * S_1)
180+
#
181+
# yields:
182+
#
183+
# (Q_2 - ZP_2) == (Q_0 - ZP_0) * (Q_1 - ZP_1)
184+
# S_2 = S_0 * S_1
185+
output_scale = math.prod(
186+
(qp.get_scale_per_tensor() for qp in inputs_qparams.values())
187+
)
188+
return self._int32_qargs(output_scale)
164189
else:
165190
raise ValueError(f"Not a valid target: {target}")
166191

@@ -187,7 +212,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
187212
modified = False
188213
for i in qargs:
189214
qp = qargs[i]
190-
if qp.dtype != torch.int8:
215+
if qp.dtype not in (torch.int8, torch.int16):
191216
continue
192217

193218
arg_node = args_copy[i]
@@ -226,7 +251,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
226251
assert rescale_qargs is not None
227252

228253
qarg = qargs[0]
229-
if qarg.dtype != torch.int8:
254+
if qarg.dtype not in (torch.int8, torch.int16):
230255
return False
231256

232257
users_copy = list(node.users)
@@ -237,7 +262,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
237262
exir_ops.backend.tosa.RESCALE.default,
238263
(
239264
node,
240-
torch.int8,
265+
qarg.dtype,
241266
rescale_qargs.get_scale_per_tensor()
242267
/ qarg.get_scale_per_tensor(), # Old scale / new scale
243268
rescale_qargs.get_zp_per_tensor(), # Old zero point

backends/arm/operators/op_mul.py

Lines changed: 6 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,8 @@
77

88
from typing import Any, List
99

10-
import executorch.backends.arm.tosa.quant_utils as tqutils
11-
import executorch.backends.arm.tosa.utils as tutils
1210
import torch
1311

14-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
15-
get_input_qparams,
16-
)
17-
1812
from executorch.backends.arm.operators.node_visitor import (
1913
NodeVisitor,
2014
register_node_visitor,
@@ -24,17 +18,17 @@
2418
validate_same_dtype,
2519
validate_valid_dtype,
2620
)
27-
from executorch.backends.arm.tosa import TosaSpecification
2821
from executorch.backends.arm.tosa.mapping import TosaArg
22+
from executorch.backends.arm.tosa.specification import TosaSpecification
2923

3024

3125
@register_node_visitor
32-
class MulVisitor_INT(NodeVisitor):
26+
class MulVisitor(NodeVisitor):
3327
target = "aten.mul.Tensor"
3428

3529
tosa_specs = [
30+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3631
TosaSpecification.create_from_string("TOSA-1.0+INT"),
37-
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
3832
]
3933

4034
def define_node(
@@ -52,105 +46,13 @@ def define_node(
5246
validate_valid_dtype(
5347
self.target,
5448
[*inputs, output],
55-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
49+
[ts.DType.INT32, ts.DType.FP32],
5650
output.tosa_spec,
5751
)
5852

59-
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
60-
input_A = inputs[0]
61-
input_B = inputs[1]
62-
input_qparams = get_input_qparams(node)
63-
input_A_qargs = input_qparams[0]
64-
input_B_qargs = input_qparams[1]
65-
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
66-
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
67-
68-
# Rescale inputs to INT32 with zp=0
69-
input_A_rescaled = tqutils.build_rescale_to_int32(
70-
tosa_graph,
71-
input_A,
72-
input_A_qargs.get_zp_per_tensor(),
73-
1.0,
74-
tosa_spec=self.tosa_spec,
75-
)
76-
input_B_rescaled = tqutils.build_rescale_to_int32(
77-
tosa_graph,
78-
input_B,
79-
input_B_qargs.get_zp_per_tensor(),
80-
1.0,
81-
tosa_spec=self.tosa_spec,
82-
)
83-
else:
84-
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
85-
# Non quantized input, natively support by TOSA.MUL
86-
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
87-
88-
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
89-
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
90-
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
91-
else:
92-
# output.dtype == ts.DType.INT32 (non-quantized)
93-
mul_output = output
94-
95-
# Do the INT32 Mul
96-
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
97-
self._serialize_operator(
98-
node,
99-
tosa_graph,
100-
ts.TosaOp.Op().MUL,
101-
[input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"],
102-
[mul_output.name],
103-
)
104-
105-
if output.dtype == ts.DType.INT8:
106-
# Scale output back to 8 bit
107-
output_scale = (
108-
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
109-
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
110-
)
111-
tqutils.insert_rescale_op_to_int8(
112-
tosa_graph, mul_output, output_scale, node, self.tosa_spec
113-
)
114-
elif output.dtype == ts.DType.INT16:
115-
# Scale output back to 16 bit
116-
output_scale = (
117-
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
118-
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
119-
)
120-
tqutils.insert_rescale_op_to_int16(
121-
tosa_graph, mul_output, output_scale, node, self.tosa_spec
122-
)
123-
124-
125-
@register_node_visitor
126-
class MulVisitor_FP(MulVisitor_INT):
127-
# inheriting 'target' from INT class
128-
129-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
130-
131-
def define_node(
132-
self,
133-
node: torch.fx.Node,
134-
tosa_graph: Any,
135-
inputs: List[TosaArg],
136-
output: TosaArg,
137-
) -> None:
138-
139-
import serializer.tosa_serializer as ts # type: ignore
140-
141-
validate_num_inputs(self.target, inputs, 2)
142-
validate_same_dtype(self.target, [*inputs, output], ts)
143-
144-
if inputs[0].dtype == ts.DType.INT8:
145-
return super().define_node(node, tosa_graph, inputs, output)
146-
147-
input1, input2 = inputs
148-
14953
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
150-
self._serialize_operator(
151-
node,
152-
tosa_graph,
54+
tosa_graph.addOperator(
15355
ts.TosaOp.Op().MUL,
154-
[input1.name, input2.name, f"{node.name}_shift"],
56+
[inputs[0].name, inputs[1].name, f"{node.name}_shift"],
15557
[output.name],
15658
)

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ def __init__(self):
2222
super().__init__()
2323

2424
def forward(self, x, y):
25-
a = torch.maximum(x, y)
26-
b = torch.abs(a)
27-
c = a > b
28-
return c
25+
a = x * y
26+
b = torch.maximum(a, y)
27+
c = torch.abs(b)
28+
d = c > b
29+
return d
2930

3031
def get_inputs(self, dtype) -> input_t:
3132
if dtype == torch.float32:
@@ -45,8 +46,8 @@ def test_insert_rescales():
4546
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
4647
ops_after = {
4748
# "number of op nodes with i8 output" + "number of i8 node inputs"
48-
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2
49-
+ 5,
49+
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
50+
+ 7,
5051
}
5152
pipeline = PassPipeline[input_t](
5253
module,

0 commit comments

Comments
 (0)