Skip to content

Commit a647bc3

Browse files
Martin Lindström authored and oscarandersson8218 committed
Arm backend: Move rescales from ABS visitor to pass
Signed-off-by: Martin Lindström <[email protected]>
Co-authored-by: Oscar Andersson <[email protected]>
Change-Id: I62fdc5bea75361d6c32711968bdc1c9d03677ccc
1 parent 2cda2ff commit a647bc3

File tree

3 files changed

+76
-108
lines changed

3 files changed

+76
-108
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -85,50 +85,93 @@ class InsertRescaleInt32Pass(ArmPass):
8585
_passes_required_after: Set[Type[ExportPass]] = set()
8686

8787
included_targets = [
88+
exir_ops.edge.aten.abs.default,
8889
exir_ops.edge.aten.eq.Tensor,
8990
exir_ops.edge.aten.ge.Tensor,
9091
exir_ops.edge.aten.gt.Tensor,
9192
exir_ops.edge.aten.le.Tensor,
9293
exir_ops.edge.aten.lt.Tensor,
9394
]
9495

95-
def _get_rescale_qparams(
96+
def _int32_qargs(self, s):
97+
"""Helper creator function for INT32-based QuantArgs"""
98+
99+
return QuantArgs(
100+
scale=s,
101+
zp=0,
102+
qmin=torch.iinfo(torch.int32).min,
103+
qmax=torch.iinfo(torch.int32).max,
104+
dtype=torch.int32,
105+
)
106+
107+
def _get_inputs_rescaled_qparams(
96108
self, target, input_qparams: Dict[int, QuantArgs]
97-
) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
98-
"""
99-
Get the quantization parameters of the Int32 inputs/outputs that will
100-
surround the node.
101-
"""
109+
) -> Dict[int, QuantArgs]:
110+
"""Get the qparams for the INT32 operands to the op ``target``
102111
103-
# Helper creator function for Int32-based QuantArgs
104-
def int32_qargs(s):
105-
return QuantArgs(
106-
scale=s,
107-
zp=0,
108-
qmin=torch.iinfo(torch.int32).min,
109-
qmax=torch.iinfo(torch.int32).max,
110-
dtype=torch.int32,
111-
)
112+
Inputs to the INT32-based operator must be rescaled from INT8 to INT32.
113+
This function computes the ``QuantArgs`` for each of the operands and returns
114+
it as a dict, mapping tensor index to ``QuantArgs``.
115+
"""
112116

113117
if target in [
118+
exir_ops.edge.aten.abs.default,
114119
exir_ops.edge.aten.eq.Tensor,
115120
exir_ops.edge.aten.ge.Tensor,
116121
exir_ops.edge.aten.gt.Tensor,
117122
exir_ops.edge.aten.le.Tensor,
118123
exir_ops.edge.aten.lt.Tensor,
119124
]:
120-
# Use the lowest scale of the operands since that yields the best numerical precision.
125+
# For these ops, use the smallest scale among the INT8 operands.
121126
min_scale = min(
122127
[qp.get_scale_per_tensor() for qp in input_qparams.values()]
123128
)
124-
inputs_rescale_qparams = {
125-
i: int32_qargs(min_scale) for i in range(len(input_qparams))
129+
qparams = {
130+
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
126131
}
132+
else:
133+
raise ValueError(f"Not a valid target: {target}")
134+
135+
return qparams
136+
137+
def _get_output_qparams(
138+
self, target, inputs_qparams: Dict[int, QuantArgs]
139+
) -> Optional[QuantArgs]:
140+
"""Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute
141+
the scale of the output based on how the operator itself affects it."""
127142

128-
# Return None as output quant args since the output is not quantized (bool dtype)
129-
return (inputs_rescale_qparams, None)
143+
if target in [
144+
exir_ops.edge.aten.abs.default,
145+
]:
146+
# The op has not altered the scale; the output scale is equal to
147+
# the operands' scales.
148+
return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor())
149+
elif target in [
150+
exir_ops.edge.aten.eq.Tensor,
151+
exir_ops.edge.aten.ge.Tensor,
152+
exir_ops.edge.aten.gt.Tensor,
153+
exir_ops.edge.aten.le.Tensor,
154+
exir_ops.edge.aten.lt.Tensor,
155+
]:
156+
# Output is bool for these ops and thus no qparams are present
157+
return None
130158
else:
131-
raise ValueError(f"Unknown target: {target}")
159+
raise ValueError(f"Not a valid target: {target}")
160+
161+
def _get_rescale_qparams(
162+
self, target, input_qparams: Dict[int, QuantArgs]
163+
) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
164+
"""
165+
Get the quantization parameters of the INT32 inputs/outputs that will
166+
surround the node after the new RESCALE ops have been inserted.
167+
"""
168+
169+
inputs_rescaled_qparams = self._get_inputs_rescaled_qparams(
170+
target, input_qparams
171+
)
172+
output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams)
173+
174+
return (inputs_rescaled_qparams, output_qparams)
132175

133176
def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
134177
qargs = node.meta["input_qparams"]

backends/arm/operators/op_abs.py

Lines changed: 7 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
# pyre-unsafe
77
from typing import Any, List
88

9-
import executorch.backends.arm.tosa.quant_utils as tqutils
10-
import executorch.backends.arm.tosa.utils as tutils
11-
129
from executorch.backends.arm.operators.node_visitor import (
1310
NodeVisitor,
1411
register_node_visitor,
@@ -18,22 +15,20 @@
1815
validate_same_dtype,
1916
validate_valid_dtype,
2017
)
21-
from executorch.backends.arm.tosa import TosaSpecification
2218
from executorch.backends.arm.tosa.mapping import TosaArg
19+
from executorch.backends.arm.tosa.specification import TosaSpecification
2320
from torch.fx import Node
2421

2522

2623
@register_node_visitor
27-
class AbsVisitor_INT(NodeVisitor):
24+
class AbsVisitor(NodeVisitor):
2825
target = "aten.abs.default"
2926

3027
tosa_specs = [
3128
TosaSpecification.create_from_string("TOSA-1.0+INT"),
29+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3230
]
3331

34-
def __init__(self, *args):
35-
super().__init__(*args)
36-
3732
def define_node(
3833
self,
3934
node: Node,
@@ -47,89 +42,18 @@ def define_node(
4742
validate_num_inputs(self.target, inputs, 1)
4843
validate_same_dtype(self.target, [*inputs, output], ts)
4944

50-
# Handle int8 (quantized) and int32
5145
validate_valid_dtype(
5246
self.target,
5347
[*inputs, output],
54-
[ts.DType.INT8, ts.DType.INT32],
48+
[ts.DType.INT32, ts.DType.FP32],
5549
output.tosa_spec,
5650
)
5751

58-
scale_back = 1.0
59-
if inputs[0].dtype == ts.DType.INT8:
60-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
61-
tosa_graph, inputs, node, self.tosa_spec
62-
) # type: ignore[possibly-undefined]
63-
else:
64-
# input[0].dtype == ts.DType.INT32
65-
# Non quantized input, natively support by TOSA.abs
66-
rescaled_inputs = inputs
67-
68-
if output.dtype == ts.DType.INT8:
69-
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
70-
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
71-
else:
72-
# output.dtype == ts.DType.INT32
73-
abs_output = output
74-
75-
# Do the INT32 Abs
76-
self._serialize_operator(
77-
node,
78-
tosa_graph,
52+
tosa_graph.addOperator(
7953
ts.TosaOp.Op().ABS,
8054
[
81-
rescaled_inputs[0].name,
55+
inputs[0].name,
8256
],
83-
[abs_output.name],
57+
[output.name],
8458
None,
8559
)
86-
87-
if output.dtype == ts.DType.INT8:
88-
# Scale output back to 8 bit
89-
# pyre-ignore
90-
tqutils.insert_rescale_op_to_int8(
91-
tosa_graph, abs_output, scale_back, node, self.tosa_spec
92-
) # type: ignore[possibly-undefined]
93-
94-
95-
@register_node_visitor
96-
class AbsVisitor_FP(AbsVisitor_INT):
97-
# inheriting 'target' from BI class
98-
99-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
100-
101-
def __init__(self, *args):
102-
super().__init__(*args)
103-
104-
def define_node(
105-
self,
106-
node: Node,
107-
tosa_graph: Any,
108-
inputs: List[TosaArg],
109-
output: TosaArg,
110-
) -> None:
111-
112-
import serializer.tosa_serializer as ts # type: ignore
113-
114-
validate_num_inputs(self.target, inputs, 1)
115-
validate_same_dtype(self.target, [*inputs, output], ts)
116-
117-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
118-
# Call the inherited define_node for handling integers
119-
super().define_node(node, tosa_graph, inputs, output)
120-
else:
121-
# FP32 Abs lowering
122-
123-
validate_valid_dtype(
124-
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
125-
)
126-
127-
# MI lowering
128-
self._serialize_operator(
129-
node,
130-
tosa_graph,
131-
ts.TosaOp.Op().ABS,
132-
[inputs[0].name],
133-
[output.name],
134-
None,
135-
)

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ def __init__(self):
2222
super().__init__()
2323

2424
def forward(self, x, y):
25-
a = x > y
26-
return a
25+
a = torch.abs(x)
26+
b = a > y
27+
return b
2728

2829
def get_inputs(self, dtype) -> input_t:
2930
if dtype == torch.float32:
@@ -43,8 +44,8 @@ def test_insert_rescales():
4344
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
4445
ops_after = {
4546
# "number of op nodes with i8 output" + "number of i8 node inputs"
46-
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 0
47-
+ 2,
47+
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 1
48+
+ 3,
4849
}
4950
pipeline = PassPipeline[input_t](
5051
module,

0 commit comments

Comments
 (0)