Skip to content

Commit a113625

Browse files
committed
Arm backend: Decompose sum in pass
Moves the unrolling of reducing multiple indices from the sum node visitor to a new DecomposeSumPass. KeepDimsFalseToSqueezePass is merged into the new pass to decompose the sum op fully in one pass. This change introduces new rescales for each reduced dim, requiring decomposition before quantization to get proper quantization parameters. Change-Id: I1b113813f22c6b25aac56d63110d7eee4833167a Signed-off-by: Adrian Lundell <[email protected]>
1 parent dcd25eb commit a113625

File tree

5 files changed

+156
-205
lines changed

5 files changed

+156
-205
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
3333
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
3434
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
35+
from .decompose_sum_pass import DecomposeSumPass # noqa
3536
from .decompose_var_pass import DecomposeVarPass # noqa
3637
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3738
FoldAndAnnotateQParamsPass,
@@ -44,7 +45,6 @@
4445
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
4546
from .insert_rescales_pass import InsertRescalePass # noqa
4647
from .insert_table_ops import InsertTableOpsPass # noqa
47-
from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
4848
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
4949
from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
5050
from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
DecomposeSoftmaxPass,
3838
DecomposeSoftmaxUnstablePass,
3939
DecomposeSqrtPass,
40+
DecomposeSumPass,
4041
DecomposeVarPass,
4142
FoldAndAnnotateQParamsPass,
4243
FuseBatchnorm2DPass,
@@ -45,7 +46,6 @@
4546
FuseQuantizedActivationPass,
4647
InsertRescalePass,
4748
InsertTableOpsPass,
48-
KeepDimsFalseToSqueezePass,
4949
MatchArgRanksPass,
5050
MatchWhereSelfDtypePass,
5151
QuantizeOperatorArguments,
@@ -110,7 +110,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
110110
self.add_pass(ConvertExpandCopyToRepeatPass())
111111
self.add_pass(UnsqueezeBeforeRepeatPass())
112112
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
113-
self.add_pass(KeepDimsFalseToSqueezePass())
113+
self.add_pass(DecomposeSumPass())
114114
self.add_pass(Conv1dUnsqueezePass(exported_program))
115115
self.add_pass(DecomposeSelectPass())
116116
self.add_pass(ConvertSqueezesToViewPass())
@@ -163,7 +163,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
163163
self.add_pass(ConvertExpandCopyToRepeatPass())
164164
self.add_pass(UnsqueezeBeforeRepeatPass())
165165
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
166-
self.add_pass(KeepDimsFalseToSqueezePass())
166+
self.add_pass(DecomposeSumPass())
167167
self.add_pass(Conv1dUnsqueezePass(exported_program))
168168
self.add_pass(DecomposeSelectPass())
169169
self.add_pass(ConvertSqueezesToViewPass())
@@ -220,4 +220,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
220220

221221
self.add_pass(ConvertMinMaxPass())
222222
self.add_pass(ReplaceInfValues())
223+
self.add_pass(DecomposeSumPass())
224+
223225
return self._transform(graph_module)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
from executorch.exir.pass_base import ExportPass
9+
10+
11+
def _get_sum_decomp(op):
12+
match op:
13+
case exir_ops.edge.aten.sum.dim_IntList:
14+
return (
15+
exir_ops.edge.aten.view_copy.default,
16+
exir_ops.edge.aten.sum.dim_IntList,
17+
)
18+
case torch.ops.aten.sum.dim_IntList:
19+
return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
20+
case _:
21+
raise RuntimeError("Unvalid op in DecomposeSumPass")
22+
23+
24+
class DecomposeSumPass(ExportPass):
25+
"""
26+
In Pytorch, the default behaviour of for example Tensor.sum is to squeeze the
27+
dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM always
28+
preserves the rank of the input (keep_dim = True). To get a 1-1 mapping in the sum
29+
lowering, normalize the keep_dim = False case to keep_dim = True and lower the rank
30+
with a view op.
31+
32+
Since TOSA can only reduce one dimension at a time, multiple dims are additionally
33+
unrolled into multiple ops.
34+
35+
Original:
36+
sum((dim_1, dim_2), keep_dim = False) -> squeezed_shape
37+
After pass:
38+
sum(dim_1, keep_dim = True) -> unsqueezed_shape
39+
sum(dim_2, keep_dim = True) -> unsqueezed_shape
40+
view(shape = squeezed_shape) -> squeezed_shape
41+
"""
42+
43+
def call_operator(self, op, args, kwargs, meta):
44+
if op not in [
45+
exir_ops.edge.aten.sum.dim_IntList,
46+
torch.ops.aten.sum.dim_IntList,
47+
]:
48+
return super().call_operator(op, args, kwargs, meta)
49+
50+
match len(args):
51+
case 3:
52+
(
53+
input_node,
54+
dims,
55+
keepdims,
56+
) = args
57+
case 2:
58+
(
59+
input_node,
60+
dims,
61+
) = args
62+
keepdims = False
63+
case _:
64+
raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")
65+
66+
view_op, sum_op = _get_sum_decomp(op)
67+
68+
for dim in dims:
69+
input_node = super().call_operator(
70+
sum_op, (input_node, dim, True), kwargs, meta
71+
)
72+
73+
if not keepdims:
74+
shape = list(meta["val"].size())
75+
input_node = super().call_operator(
76+
view_op, (input_node, shape), kwargs, meta
77+
)
78+
79+
return input_node

backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

0 commit comments

Comments
 (0)