Skip to content

Commit 463ed4a

Browse files
Martin Lindström
authored and committed
Arm backend: Move rescales from MAXIMUM/MINIMUM visitor to pass
Signed-off-by: Martin Lindström <[email protected]>
Change-Id: I5ee8f97590ce599a9dfc60ced0775654fa565c4e
1 parent a647bc3 commit 463ed4a

File tree

4 files changed

+23
-87
lines changed

4 files changed

+23
-87
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class InsertRescaleInt32Pass(ArmPass):
9191
exir_ops.edge.aten.gt.Tensor,
9292
exir_ops.edge.aten.le.Tensor,
9393
exir_ops.edge.aten.lt.Tensor,
94+
exir_ops.edge.aten.maximum.default,
95+
exir_ops.edge.aten.minimum.default,
9496
]
9597

9698
def _int32_qargs(self, s):
@@ -121,6 +123,8 @@ def _get_inputs_rescaled_qparams(
121123
exir_ops.edge.aten.gt.Tensor,
122124
exir_ops.edge.aten.le.Tensor,
123125
exir_ops.edge.aten.lt.Tensor,
126+
exir_ops.edge.aten.minimum.default,
127+
exir_ops.edge.aten.maximum.default,
124128
]:
125129
# For these ops, use the smallest scale among the INT8 operands.
126130
min_scale = min(
@@ -142,6 +146,8 @@ def _get_output_qparams(
142146

143147
if target in [
144148
exir_ops.edge.aten.abs.default,
149+
exir_ops.edge.aten.maximum.default,
150+
exir_ops.edge.aten.minimum.default,
145151
]:
146152
# The op has not altered the scale; the output scale is equal to
147153
# the operands' scales.

backends/arm/operators/op_maximum.py

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@
77

88
from typing import Any, List
99

10-
import executorch.backends.arm.tosa.quant_utils as tqutils
11-
12-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
13-
get_input_qparams,
14-
)
15-
1610
from executorch.backends.arm.operators.node_visitor import (
1711
NodeVisitor,
1812
register_node_visitor,
@@ -22,9 +16,8 @@
2216
validate_same_dtype,
2317
validate_valid_dtype,
2418
)
25-
from executorch.backends.arm.tosa import TosaSpecification
2619
from executorch.backends.arm.tosa.mapping import TosaArg
27-
from executorch.backends.arm.tosa.utils import tosa_shape
20+
from executorch.backends.arm.tosa.specification import TosaSpecification
2821
from torch.fx import Node
2922

3023

@@ -56,51 +49,22 @@ def define_node(
5649
validate_valid_dtype(
5750
self.target,
5851
[*inputs, output],
59-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
52+
[ts.DType.INT32, ts.DType.FP32],
6053
output.tosa_spec,
6154
)
6255

63-
scale_back = 1.0
64-
max_output = output
65-
if inputs[0].dtype == ts.DType.INT8:
66-
input_qparams = get_input_qparams(node)
67-
if len(input_qparams) != 2:
68-
raise ValueError(
69-
f"Both inputs need to have quantization information for {node}"
70-
)
71-
if input_qparams[0] != input_qparams[1]:
72-
raise ValueError(
73-
"Both inputs must have the same quantization parameters for MAX"
74-
)
75-
76-
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
77-
tosa_graph, inputs, node, self.tosa_spec
78-
)
79-
80-
output.shape = tosa_shape(output.shape, output.dim_order)
81-
max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
82-
else:
83-
operand_inputs = inputs
84-
8556
attr_maximum = ts.TosaSerializerAttribute()
86-
87-
# Set to PROPOGATE as default
57+
# Set to PROPAGATE as default
8858
attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE)
8959

9060
self._serialize_operator(
9161
node,
9262
tosa_graph,
9363
ts.TosaOp.Op().MAXIMUM,
9464
[
95-
operand_inputs[0].name,
96-
operand_inputs[1].name,
65+
inputs[0].name,
66+
inputs[1].name,
9767
],
98-
[max_output.name],
68+
[output.name],
9969
attr_maximum,
10070
)
101-
102-
if output.dtype == ts.DType.INT8:
103-
# insert RESCALE from int32 back to int8
104-
tqutils.insert_rescale_op_to_int8(
105-
tosa_graph, max_output, scale_back, node, self.tosa_spec
106-
)

backends/arm/operators/op_minimum.py

Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77

88
from typing import Any, List
99

10-
import executorch.backends.arm.tosa.quant_utils as tqutils
11-
12-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
13-
get_input_qparams,
14-
)
1510
from executorch.backends.arm.operators.node_visitor import (
1611
NodeVisitor,
1712
register_node_visitor,
@@ -23,7 +18,6 @@
2318
)
2419
from executorch.backends.arm.tosa import TosaSpecification
2520
from executorch.backends.arm.tosa.mapping import TosaArg
26-
from executorch.backends.arm.tosa.utils import tosa_shape
2721
from torch.fx import Node
2822

2923

@@ -55,51 +49,22 @@ def define_node(
5549
validate_valid_dtype(
5650
self.target,
5751
[*inputs, output],
58-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
52+
[ts.DType.INT32, ts.DType.FP32],
5953
output.tosa_spec,
6054
)
6155

62-
scale_back = 1.0
63-
min_output = output
64-
if inputs[0].dtype == ts.DType.INT8:
65-
input_qparams = get_input_qparams(node)
66-
if len(input_qparams) != 2:
67-
raise ValueError(
68-
f"Both inputs need to have quantization information for {node}"
69-
)
70-
if input_qparams[0] != input_qparams[1]:
71-
raise ValueError(
72-
"Both inputs must have the same quantization parameters for MIN"
73-
)
74-
75-
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
76-
tosa_graph, inputs, node, self.tosa_spec
77-
)
78-
79-
output.shape = tosa_shape(output.shape, output.dim_order)
80-
min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
81-
else:
82-
operand_inputs = inputs
83-
8456
attr_minimum = ts.TosaSerializerAttribute()
85-
86-
# Set to PROPOGATE as default
57+
# Set to PROPAGATE as default
8758
attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE)
8859

8960
self._serialize_operator(
9061
node,
9162
tosa_graph,
9263
ts.TosaOp.Op().MINIMUM,
9364
[
94-
operand_inputs[0].name,
95-
operand_inputs[1].name,
65+
inputs[0].name,
66+
inputs[1].name,
9667
],
97-
[min_output.name],
68+
[output.name],
9869
attr_minimum,
9970
)
100-
101-
if output.dtype == ts.DType.INT8:
102-
# insert RESCALE from int32 back to int8
103-
tqutils.insert_rescale_op_to_int8(
104-
tosa_graph, min_output, scale_back, node, self.tosa_spec
105-
)

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def __init__(self):
2222
super().__init__()
2323

2424
def forward(self, x, y):
25-
a = torch.abs(x)
26-
b = a > y
27-
return b
25+
a = torch.maximum(x, y)
26+
b = torch.abs(a)
27+
c = a > b
28+
return c
2829

2930
def get_inputs(self, dtype) -> input_t:
3031
if dtype == torch.float32:
@@ -44,8 +45,8 @@ def test_insert_rescales():
4445
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
4546
ops_after = {
4647
# "number of op nodes with i8 output" + "number of i8 node inputs"
47-
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 1
48-
+ 3,
48+
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2
49+
+ 5,
4950
}
5051
pipeline = PassPipeline[input_t](
5152
module,

0 commit comments

Comments
 (0)