
Commit 25cf3e9

Martin Lindström and oscarandersson8218 authored and committed
Arm backend: Move rescales from ADD visitor to pass
Move the insertion of INT8/INT32 RESCALE ops from the ADD node visitor to the pass InsertRescaleInt32Pass. This is in practice a refactoring patch, but the output TOSA file still differs enough to make an Ethos-U55 test in test_conv_relu_residual_add.py fail. That issue was fixed in https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/commit/642f7517d3a6bd053032e1942822f6e38ccd546f, so we temporarily mark the failing test as xfail until the Ethos-U Vela compiler version we depend on is bumped to one that includes the fix.

Signed-off-by: Martin Lindström <[email protected]>
Co-authored-by: Oscar Andersson <[email protected]>
Change-Id: I90cd9ab5a296911a228b0080008e4c65ba773f7c
1 parent de56c81 commit 25cf3e9
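For orientation, a minimal sketch of what the pass now produces for a quantized add (an illustrative pseudo-graph; the actual node names and signatures differ):

    # Before the pass (quantized FX graph):
    #   out_i8 = add(x_i8, y_i8)
    #
    # After InsertRescaleInt32Pass (the visitor then serializes a plain ADD):
    #   x_i32  = rescale(x_i8)    # both operands brought to one shared int32 scale
    #   y_i32  = rescale(y_i8)
    #   s_i32  = add(x_i32, y_i32)
    #   out_i8 = rescale(s_i32)   # back to the output's quantization parameters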

File tree

4 files changed (+56, −115 lines)


backends/arm/_passes/insert_rescales_pass.py

Lines changed: 34 additions & 6 deletions
@@ -76,13 +76,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
 
 class InsertRescaleInt32Pass(ArmPass):
-    """
-    Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
+    """Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
     quantized implementations. This pass treats such operator nodes by
-    inserting rescale ops before and after them if needed. Note that extra logic
-    that handles the scales and zero points must be in place because the affected
-    TOSA have naive implementations that do not account for the quantization
-    parameters.
+    inserting rescale ops before and after them if needed. Note that extra
+    logic that handles the scales and zero points is in place here because the
+    affected TOSA ops have naive implementations that do not account for the
+    quantization parameters.
     """
 
     # SUM must be decomposed after this pass to prevent insertion of RESCALE
@@ -93,6 +92,7 @@ class InsertRescaleInt32Pass(ArmPass):
 
     included_targets = [
         exir_ops.edge.aten.abs.default,
+        exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.ge.Tensor,
         exir_ops.edge.aten.gt.Tensor,
@@ -142,6 +142,33 @@ def _get_inputs_rescaled_qparams(
             qparams = {
                 i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
             }
+        elif target in [
+            exir_ops.edge.aten.add.Tensor,
+        ]:
+            if input_qparams[0].dtype != input_qparams[1].dtype:
+                raise ValueError(
+                    f"Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
+                )
+
+            # We are handling two INT8 or two INT16 numbers. For INT8, if the
+            # zero point is non-null, the result will be in the range [-255;
+            # 255], therefore we need 9 bits for the result. We have a 32-bit
+            # accumulator, so we can divide the scale by (1 << 20) which is
+            # equivalent to shifting the INT8 operands 20 bits to the left
+            # before rescaling them both to 2 * max(lhs, rhs).
+            #
+            # For INT16, similar logic can be applied, but we instead end up
+            # with a left shift of 12.
+            lhs_scale, rhs_scale = (
+                qp.get_scale_per_tensor() for qp in input_qparams.values()
+            )
+            max_scale_2x = 2 * max(lhs_scale, rhs_scale)
+
+            # Select shift based on input dtype.
+            shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20
+
+            scale = max_scale_2x / (1 << shift_bits)
+            qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
         elif target in [
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,
@@ -168,6 +195,7 @@ def _get_output_qparams(
             exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.minimum.default,
             exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.add.Tensor,
         ]:
             # The op has not altered the scale; the output scale is equal to
             # the operands' scales.
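To make the shift arithmetic above concrete, here is a minimal standalone sketch (plain Python, no ExecuTorch imports; the helper name is made up) of how the shared int32 input scale for ADD is derived:

    def int32_input_scale(lhs_scale: float, rhs_scale: float, is_int16: bool) -> float:
        # int8 + int8 needs at most 9 bits (int16 + int16 needs 17), so the
        # operands can be shifted left by 20 (or 12) bits and the sum still
        # fits in the 32-bit accumulator.
        shift_bits = 12 if is_int16 else 20
        max_scale_2x = 2 * max(lhs_scale, rhs_scale)
        return max_scale_2x / (1 << shift_bits)

    # Two int8 operands with per-tensor scales 0.02 and 0.05:
    print(int32_input_scale(0.02, 0.05, is_int16=False))  # 0.1 / 2**20 ≈ 9.54e-08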

backends/arm/operators/op_add.py

Lines changed: 7 additions & 103 deletions
@@ -7,8 +7,6 @@
 
 from typing import Any, List
 
-import executorch.backends.arm.tosa.quant_utils as tqutils
-import executorch.backends.arm.tosa.utils as tutils
 import tosa_serializer as ts
 
 from executorch.backends.arm.operators.node_visitor import (
@@ -20,22 +18,20 @@
     validate_same_dtype,
     validate_valid_dtype,
 )
-from executorch.backends.arm.tosa import TosaSpecification
 from executorch.backends.arm.tosa.mapping import TosaArg
+from executorch.backends.arm.tosa.specification import TosaSpecification
 from torch.fx import Node
 
 
 @register_node_visitor
-class AddVisitor_INT(NodeVisitor):
+class AddVisitor(NodeVisitor):
     target = "aten.add.Tensor"
 
     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]
 
-    def __init__(self, *args):
-        super().__init__(*args)
-
     def define_node(
         self,
         node: Node,
@@ -45,113 +41,21 @@ def define_node(
     ) -> None:
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)
-        valid_dtypes = []
-        if self.tosa_spec.support_integer():
-            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
-        if self.tosa_spec.support_float():
-            valid_dtypes.extend([ts.DType.INT32])
-
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            valid_dtypes,
+            [ts.DType.INT32, ts.DType.FP32],
             output.tosa_spec,
         )
-        scale_back = 1.0
-        if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
-                tosa_graph, inputs, node, self.tosa_spec
-            )
-        elif inputs[0].dtype == ts.DType.INT16:
-            rescaled_inputs, scale_back = (
-                tqutils.insert_rescale_ops_int16_to_int32_maxscale(
-                    tosa_graph, inputs, node, self.tosa_spec
-                )
-            )
-        else:
-            # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
-            # Non quantized input, natively support by TOSA.ADD
-            rescaled_inputs = inputs
-
-        if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
-            broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
-            add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
-        else:
-            # output.dtype == ts.DType.INT16 or ts.DType.INT32
-            add_output = output
 
-        input1, input2 = rescaled_inputs
         attr = ts.TosaSerializerAttribute()
         attr.AddAttribute()
-        # Do the INT32 Add
+
         self._serialize_operator(
             node,
             tosa_graph,
            ts.Op.ADD,
-            [input1.name, input2.name],
-            [add_output.name],
+            [inputs[0].name, inputs[1].name],
+            [output.name],
             attr,
         )
-
-        if output.dtype == ts.DType.INT8:
-            # Scale output back to 8 bit
-            # pyre-ignore
-            tqutils.insert_rescale_op_to_int8(
-                tosa_graph,
-                add_output,
-                scale_back,
-                node,
-                compute_rescale=False,
-                tosa_spec=self.tosa_spec,
-            )  # type: ignore[possibly-undefined]
-        elif output.dtype == ts.DType.INT16:
-            tqutils.insert_rescale_op_to_int16(
-                tosa_graph,
-                add_output,
-                scale_back,
-                node,
-                compute_rescale=False,
-                tosa_spec=self.tosa_spec,
-            )  # type: ignore[possibly-undefined]
-
-
-@register_node_visitor
-class AddVisitor_FP(AddVisitor_INT):
-    # inheriting 'target' from INT class
-
-    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        validate_num_inputs(self.target, inputs, 2)
-        validate_same_dtype(self.target, [*inputs, output], ts)
-
-        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
-            # Call the inherited define_node for handling integers
-            super().define_node(node, tosa_graph, inputs, output)
-        else:
-            # FP32 Add lowering
-            validate_valid_dtype(
-                self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
-            )
-
-            input1, input2 = inputs
-            attr = ts.TosaSerializerAttribute()
-            attr.AddAttribute()
-            # FP lowering
-            self._serialize_operator(
-                node,
-                tosa_graph,
-                ts.Op.ADD,
-                [input1.name, input2.name],
-                [output.name],
-                attr,
-            )
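With the rescales gone from the visitor, the numeric behaviour is easiest to see end to end. A hedged, self-contained sketch of the rescale → int32 add → rescale round trip (plain numpy, zero points assumed 0 for brevity; not the backend's actual rescale kernels):

    import numpy as np

    def quantize(x, scale, dtype):
        info = np.iinfo(dtype)
        return np.clip(np.round(x / scale), info.min, info.max).astype(dtype)

    # Two int8 operands quantized with different per-tensor scales.
    lhs_scale, rhs_scale = 0.02, 0.05
    lhs = quantize(np.array([1.0, -0.5]), lhs_scale, np.int8)  # -> [50, -25]
    rhs = quantize(np.array([0.3, 0.7]), rhs_scale, np.int8)   # -> [6, 14]

    # Shared int32 scale with headroom, as computed by the pass for ADD.
    int32_scale = 2 * max(lhs_scale, rhs_scale) / (1 << 20)

    # RESCALE both operands to int32, ADD, then RESCALE to the output qparams.
    acc = (np.round(lhs * lhs_scale / int32_scale)
           + np.round(rhs * rhs_scale / int32_scale)).astype(np.int64)
    out = quantize(acc * int32_scale, rhs_scale, np.int8)

    print(out * rhs_scale)  # ~[1.3, 0.2], the float sum of the inputs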

backends/arm/test/misc/test_conv_relu_residual_add.py

Lines changed: 7 additions & 0 deletions
@@ -76,6 +76,13 @@ def test_tosa_INT(per_channel_quantization):
     pipeline.run()
 
 
+# TODO: Xfail until the Ethos-U Vela compiler ships commit
+# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that
+# causes this test to fail.
+@pytest.mark.xfail(
+    reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"),
+    strict=True,
+)
 @pytest.mark.slow
 @common.XfailIfNoCorstone300
 @common.parametrize("per_channel_quantization", quant_test_data)
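One detail worth noting: strict=True makes pytest treat an unexpected pass (XPASS) as a failure, so the marker cannot silently outlive the Vela fix. A minimal illustration (a hypothetical test, not from this suite):

    import pytest

    @pytest.mark.xfail(reason="blocked by an upstream fix", strict=True)
    def test_pending_upstream_fix():
        # Fails today. Once the upstream fix lands this assertion passes, and
        # strict=True turns the resulting XPASS into a suite failure, forcing
        # the stale marker to be removed.
        assert False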

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 8 additions & 6 deletions
@@ -19,11 +19,13 @@ class MultipleOpsModel(torch.nn.Module):
     input_t = Tuple[torch.Tensor, torch.Tensor]
 
     def forward(self, x, y):
-        a = x * y
-        b = torch.maximum(a, y)
-        c = torch.abs(b)
-        d = c > b
-        return d
+        a = x + y
+        b = x * a
+        c = torch.maximum(a, b)
+        d = torch.abs(b)
+        e = c + d
+        f = e > a
+        return f
 
     def get_inputs(self, dtype) -> input_t:
         if dtype == torch.float32:
@@ -38,7 +40,7 @@ def get_inputs(self, dtype) -> input_t:
 
     def get_num_expected_rescales(self):
         # "number of op nodes with i8 output" + "number of i8 node inputs"
-        return 3 + 7
+        return 5 + 11
 
 
 class SumModel(torch.nn.Module):
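For reference, the updated expectation of 5 + 11 can be tallied from the new forward (assuming, per the comment above, one rescale per int8 node input and one per int8 op output): the six op nodes are add, mul, maximum, abs, add, and gt; the comparison produces a bool, leaving five int8 outputs, and the operand counts 2 + 2 + 2 + 1 + 2 + 2 give eleven int8 inputs.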
