
Commit 6cb132f

Martin Lindström authored and committed
Arm backend: Move rescales from SUM visitor to pass
Previously, the SUM node visitor inserted an INT8->INT32 RESCALE node before the SUM node and an INT32->INT8 RESCALE node after it. This patch moves that insertion into `InsertRescaleInt32Pass`.

Since SUM is decomposed, the RESCALE nodes must be inserted before `DecomposeSumPass` runs (the pass that decomposes a SUM into a chain of single-dim SUMs). The ordering matters: it avoids redundant INT8/INT32 RESCALE nodes between each SUM node in the chain after decomposition. Only one INT8->INT32 RESCALE is needed before the chain and one INT32->INT8 RESCALE after it; between the SUM nodes in the chain, the edges already carry the correct INT32 data type.

Signed-off-by: Martin Lindström <[email protected]>
Change-Id: I86dd5c34b50ca6cbba6ad98e1490c9b7effc3b3b
1 parent fedaa2d commit 6cb132f
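To make the ordering argument concrete, here is a small, hypothetical sketch (not part of the patch) showing that a multi-dim SUM and the chain of single-dim SUMs produced by `DecomposeSumPass` compute the same result, which is why a single RESCALE pair around the whole chain suffices:

import torch

# Hypothetical illustration only; these names are not from the patch.
#
# With InsertRescaleInt32Pass running before DecomposeSumPass, the INT
# pipeline conceptually produces:
#   x(i8) -> RESCALE(i8->i32) -> SUM(dim=1) -> SUM(dim=2) -> RESCALE(i32->i8)
# instead of wrapping every decomposed SUM in its own RESCALE pair:
#   x(i8) -> RESCALE(i8->i32) -> SUM(dim=1) -> RESCALE(i32->i8)
#         -> RESCALE(i8->i32) -> SUM(dim=2) -> RESCALE(i32->i8)
#
# The decomposition itself preserves the result, so the edge between the two
# SUMs never needs to drop back to INT8:
x = torch.randint(-128, 128, (1, 2, 3, 4), dtype=torch.int32)
multi_dim = torch.sum(x, dim=[1, 2], keepdim=True)
chained = torch.sum(torch.sum(x, dim=1, keepdim=True), dim=2, keepdim=True)
assert torch.equal(multi_dim, chained)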

File tree

7 files changed (+64, -88 lines)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 3 deletions

@@ -194,7 +194,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
-        self.add_pass(DecomposeSumPass())
         self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
@@ -215,10 +214,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(RewriteMatmulPass())
         self.add_pass(RewriteUpsamplePass())
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
+        self.add_pass(InsertRescaleInt32Pass())
+        self.add_pass(DecomposeSumPass())
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
         self.add_pass(InsertRescalePass())
-        self.add_pass(InsertRescaleInt32Pass())

         self.validate_constraints_mandatory()
         return self._transform(exported_program.graph_module)
@@ -361,7 +361,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):

         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ReplaceInfValues())
-        self.add_pass(DecomposeSumPass())

         if not self.tosa_spec.is_U55_subset:
             # Uses where which is not supported on Ethos-U55

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
         if not keepdims:
             shape = list(meta["val"].size())
             input_node = super().call_operator(
-                view_op, (input_node, shape), kwargs, meta, updated=True
+                view_op, (input_node, shape), {}, meta, updated=True
             )

         return input_node

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 9 additions & 1 deletion

@@ -10,6 +10,7 @@
 import torch
 from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
+from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_output_qparams,
 )
@@ -84,7 +85,11 @@ class InsertRescaleInt32Pass(ArmPass):
     parameters.
     """

-    _passes_required_after: Set[Type[ExportPass]] = set()
+    # SUM must be decomposed after this pass to prevent insertion of RESCALE
+    # nodes between each subsequent SUM node after decomposition. RESCALE nodes
+    # should only be inserted before and after the SUM node prior to its
+    # decomposition.
+    _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass}

     included_targets = [
         exir_ops.edge.aten.abs.default,
@@ -96,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
         exir_ops.edge.aten.maximum.default,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.sum.dim_IntList,
     ]

     def _int32_qargs(self, s):
@@ -138,6 +144,7 @@ def _get_inputs_rescaled_qparams(
            }
        elif target in [
            exir_ops.edge.aten.mul.Tensor,
+           exir_ops.edge.aten.sum.dim_IntList,
        ]:
            # The input scales do not need to be adjusted for these ops; they
            # can remain the same.
@@ -160,6 +167,7 @@ def _get_output_qparams(
            exir_ops.edge.aten.abs.default,
            exir_ops.edge.aten.maximum.default,
            exir_ops.edge.aten.minimum.default,
+           exir_ops.edge.aten.sum.dim_IntList,
        ]:
            # The op has not altered the scale; the output scale is equal to
            # the operands' scales.

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 7 additions & 2 deletions

@@ -29,8 +29,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

         # U55 case, Vela 4.2.0 (25.02 release)
         input_shape = node.all_input_nodes[0].meta["val"].shape
-        dim_list = cast(list[int], node.args[1])
-        dim_list = [dim % len(input_shape) for dim in dim_list]
+
+        if node.args[1] is None:
+            # Dim is allowed to be None, which means to sum all dimensions
+            dim_list = list(range(len(input_shape)))
+        else:
+            dim_list = cast(list[int], node.args[1])
+            dim_list = [dim % len(input_shape) for dim in dim_list]

         for dim in dim_list:
             if not 1 <= input_shape[dim] <= 65536:
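The new None handling can be sanity-checked in isolation. This is a hypothetical snippet (not from the patch), assuming a recent PyTorch where `sum` accepts `dim=None`:

import torch

# dim=None means "reduce all dimensions"; the support check substitutes
# list(range(rank)) for None before validating the U55 shape limits.
x = torch.arange(60).reshape(10, 3, 2, 1)
all_dims = list(range(x.dim()))
assert torch.equal(torch.sum(x, dim=None, keepdim=True),
                   torch.sum(x, dim=all_dims, keepdim=True))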

backends/arm/operators/op_sum.py

Lines changed: 3 additions & 67 deletions

@@ -7,9 +7,6 @@

 from typing import Any, List

-import executorch.backends.arm.tosa.quant_utils as tqutils
-import executorch.backends.arm.tosa.utils as tutils
-
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -24,72 +21,14 @@


 @register_node_visitor
-class SumVisitor_INT(NodeVisitor):
+class SumVisitor(NodeVisitor):
     target = "aten.sum.dim_IntList"

     tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
     ]

-    def __init__(self, *args):
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-
-        import serializer.tosa_serializer as ts  # type: ignore
-
-        validate_num_inputs(self.target, inputs, 3)
-        validate_same_dtype(self.target, [inputs[0], output], ts)
-
-        tensor = inputs[0]
-        input_shape = list(tensor.shape)
-        dim = int(inputs[1].number % len(input_shape))
-
-        output_shape = input_shape
-        output_shape[dim] = 1  # Output shape is input shape with dim reduced
-
-        # Rescale input to 32 bit
-        rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
-            tosa_graph, [tensor], node, self.tosa_spec
-        )
-
-        attr = ts.TosaSerializerAttribute()
-        attr.ReduceSumAttribute(tensor.dim_order.index(dim))
-
-        intermediate = tosa_graph.addIntermediate(
-            tutils.tosa_shape(output_shape, tensor.dim_order),
-            dtype=ts.DType.INT32,
-        )
-
-        self._serialize_operator(
-            node,
-            tosa_graph,
-            ts.TosaOp.Op().REDUCE_SUM,
-            [rescaled_inputs[0].name],
-            [intermediate.name],
-            attr,
-        )
-
-        tqutils.insert_rescale_op_to_int8(
-            tosa_graph, intermediate, scale, node, self.tosa_spec
-        )
-
-
-@register_node_visitor
-class SumVisitor_FP(SumVisitor_INT):
-    # inheriting 'target' from INT class
-
-    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
     def define_node(
         self,
         node: Node,
@@ -107,17 +46,14 @@ def define_node(
         input_shape = list(tensor.shape)
         dim = int(inputs[1].number % len(input_shape))

-        output_shape = input_shape
-        output_shape[dim] = 1  # Output shape is input shape with dim reduced
-
         attr = ts.TosaSerializerAttribute()
         attr.ReduceSumAttribute(tensor.dim_order.index(dim))

         self._serialize_operator(
             node,
             tosa_graph,
             ts.TosaOp.Op().REDUCE_SUM,
-            [tensor.name],
+            [inputs[0].name],
             [output.name],
             attr,
         )

backends/arm/test/ops/test_sum.py

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ class Sum(torch.nn.Module):
         "4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True),
         "4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True),
         "dim_None": lambda: (torch.rand(10), None, True),
+        "dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True),
     }

     def forward(self, x: torch.Tensor, dim: int, keepdim: bool):

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 41 additions & 14 deletions

@@ -13,14 +13,11 @@
 from executorch.backends.arm.test.tester.test_pipeline import PassPipeline


-class NeedsRescaleOps(torch.nn.Module):
+class MultipleOpsModel(torch.nn.Module):
     """A module containing ops that require INT32 inputs/outputs."""

     input_t = Tuple[torch.Tensor, torch.Tensor]

-    def __init__(self):
-        super().__init__()
-
     def forward(self, x, y):
         a = x * y
         b = torch.maximum(a, y)
@@ -39,19 +36,41 @@ def get_inputs(self, dtype) -> input_t:
         else:
             raise ValueError("Not a valid input dtype for model")

+    def get_num_expected_rescales(self):
+        # "number of op nodes with i8 output" + "number of i8 node inputs"
+        return 3 + 7

-def test_insert_rescales():
-    module = NeedsRescaleOps()
-    input_t = Tuple[torch.Tensor, torch.Tensor]
+
+class SumModel(torch.nn.Module):
+    input_t = Tuple[torch.Tensor]
+
+    def forward(self, x):
+        a = torch.sum(x, 2, keepdim=True)  # (1, 2, 1, 4)
+        b = torch.sum(a, [1, 3], keepdim=True)  # (1, 1, 1, 1)
+        c = torch.sum(b, [0, 2], keepdim=False)  # (1, 1)
+        return c
+
+    def get_inputs(self, dtype) -> input_t:
+        if dtype == torch.float32:
+            return (torch.rand(1, 2, 3, 4),)
+        elif dtype == torch.int32:
+            return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),)
+        else:
+            raise ValueError("Not a valid input dtype for model")
+
+    def get_num_expected_rescales(self):
+        # Two RESCALE nodes per SUM node
+        return 6
+
+
+def _test_model_with_f32_data(model):
     ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
     ops_after = {
-        # "number of op nodes with i8 output" + "number of i8 node inputs"
-        "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
-        + 7,
+        "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(),
     }
-    pipeline = PassPipeline[input_t](
-        module,
-        module.get_inputs(torch.float32),
+    pipeline = PassPipeline[model.input_t](
+        model,
+        model.get_inputs(torch.float32),
         quantize=True,
         ops_not_before_pass=ops_not_before,
         ops_after_pass=ops_after,
@@ -61,8 +80,16 @@ def test_insert_rescales():
     pipeline.run()


+def test_insert_rescales_sum_model():
+    _test_model_with_f32_data(SumModel())
+
+
+def test_insert_rescales_multiple_ops_model():
+    _test_model_with_f32_data(MultipleOpsModel())
+
+
 def test_dont_insert_rescales():
-    module = NeedsRescaleOps()
+    module = MultipleOpsModel()
     input_t = Tuple[torch.Tensor, torch.Tensor]
     ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
     # All inputs are already i32. Rescales should not be added.
