
Commit f4fa279

Arm backend: Add cumsum support (#13457)
Decompose cumsum as a convolution with a kernel of ones. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 1c4de12 commit f4fa279
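
The decomposition can be sketched in a few lines of plain PyTorch. The snippet below illustrates the idea only and is not the backend pass itself: cumsum over a 1-D tensor equals a convolution with a kernel of ones applied to an input that is left-padded with len(x) - 1 zeros.

    import torch
    import torch.nn.functional as F

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    n = x.numel()
    kernel = torch.ones(1, 1, n)                  # [out_ch, in_ch, kW], all ones
    padded = F.pad(x.view(1, 1, -1), (n - 1, 0))  # left-pad only; the real pass
                                                  # pads symmetrically and slices
    out = F.conv1d(padded, kernel).view(-1)       # tensor([1., 3., 6., 10.])
    assert torch.allclose(out, torch.cumsum(x, dim=0))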

6 files changed, +270 −0 lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@
 from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
 from .decompose_cosh_pass import DecomposeCoshPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
+from .decompose_cumsum_pass import DecomposeCumsumPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions

@@ -38,6 +38,7 @@
     DecomposeBatchNormNoStatsPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
+    DecomposeCumsumPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
@@ -148,6 +149,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
+        self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())
@@ -227,6 +229,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
+        self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())
backends/arm/_passes/decompose_cumsum_pass.py (new file)

Lines changed: 142 additions & 0 deletions

@@ -0,0 +1,142 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from math import prod

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind


class DecomposeCumsumPass(ArmPass):
    """
    Decomposes cumsum into a 1D convolution with a kernel of ones.

    For example, the cumsum of an input tensor [1, 1] is [1, 1 + 1] = [1, 2].
    To decompose this, the input tensor is pre-padded with len(input) - 1 zeros
    and slid over with a kernel of ones of length len(input):

    Input:  [0, 1, 1]
    Kernel: [1, 1]       = [1]
               [1, 1]    = [2]

    Since PyTorch only supports symmetric padding, in practice the result will
    contain additional elements computed at the end, which requires an extra
    slice op.

    To extend this to higher dimensions, the input is reshaped to [N, C, H, W] with

        N = <dims before cumsum dim>
        C = 1
        H = <cumsum dim>
        W = <dims after cumsum dim>

    and the convolution is applied over dimension H.
    """

    def call(self, graph_module):
        graph = graph_module.graph
        targets = (exir_ops.edge.aten.cumsum.default, torch.ops.aten.cumsum.default)
        modified = False
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in targets:
                continue

            if len(node.args) != 2:
                raise ValueError(
                    "Cumsum node should have exactly two arguments: input and dim."
                )

            # Get node data
            input_node, dim = node.args
            val = node.meta.get("val")
            original_shape = list(val.shape)
            dtype = input_node.meta.get("val").dtype
            dim = dim % len(original_shape)

            # Compute shapes
            pre_cumsum_dim = prod(original_shape[:dim]) if dim > 0 else 1
            cumsum_dim = original_shape[dim]
            post_cumsum_dim = (
                prod(original_shape[dim + 1 :]) if dim < len(original_shape) - 1 else 1
            )
            conv_shape = [
                pre_cumsum_dim,
                1,
                cumsum_dim,
                post_cumsum_dim,
            ]
            pad_shape = [original_shape[dim] - 1, 0]
            weight_shape = [1, 1, original_shape[dim], 1]

            # Create convolution weight
            with graph.inserting_before(list(graph.nodes)[0]):
                weight_data = torch.ones(size=weight_shape, dtype=dtype)
                weight_node = create_constant_placeholder(
                    self.exported_program,
                    graph,
                    node.name + "_kernel",
                    InputKind.PARAMETER,
                    weight_data,
                )

            # Create decomposed nodes
            view_op = exir_ops.edge.aten.view_copy.default
            conv_op = exir_ops.edge.aten.convolution.default
            slice_op = exir_ops.edge.aten.slice_copy.Tensor
            with graph.inserting_before(node):
                # Reshape to 4D [N, C, H, W]
                view_args = (input_node, conv_shape)
                view_node = create_node(graph, view_op, args=view_args, from_node=node)

                conv_args = (
                    view_node,
                    weight_node,
                    None,
                    [1, 1],
                    pad_shape,
                    [1, 1],
                    False,
                    [0],
                    1,
                )
                conv_node = create_node(graph, conv_op, args=conv_args, from_node=node)

                # The convolution is inserted after quantization, so we need to set
                # our own quantization parameters for the weights here. However,
                # since the data is ones created directly as int8, it already has
                # the correct scale, so no scaling needs to be done, i.e. set
                # scale=1.0, zero_point=0.
                if (
                    "input_qparams" in conv_node.meta
                    and len(conv_node.meta["input_qparams"]) > 0
                ):
                    qparams = QuantArgs(1.0, 0.0, -128, 127, torch.int8)
                    conv_node.meta["input_qparams"][1] = qparams

                slice_args = (conv_node, 2, 0, original_shape[dim])
                slice_node = create_node(
                    graph, slice_op, args=slice_args, from_node=node
                )

                view_original_args = (slice_node, original_shape)
                view_original_node = create_node(
                    graph, view_op, args=view_original_args, from_node=node
                )

            # Replace and remove original
            node.replace_all_uses_with(view_original_node)
            graph.erase_node(node)
            modified = True

        if modified:
            # Cleanup
            graph.eliminate_dead_code()
            graph_module.recompile()
            # Apply any operator-level transforms
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
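
For reference, here is a hedged eager-mode mirror of the full recipe in the docstring above; the helper name cumsum_via_conv is introduced for illustration and is not part of the commit. It reshapes to [N, 1, H, W], convolves with a ones kernel using symmetric padding along H, slices off the extra trailing rows, and reshapes back.

    import torch
    import torch.nn.functional as F
    from math import prod

    def cumsum_via_conv(x: torch.Tensor, dim: int) -> torch.Tensor:
        shape = list(x.shape)
        dim = dim % len(shape)
        n = prod(shape[:dim]) if dim > 0 else 1
        h = shape[dim]
        w = prod(shape[dim + 1 :]) if dim < len(shape) - 1 else 1
        view = x.reshape(n, 1, h, w)
        weight = torch.ones(1, 1, h, 1, dtype=x.dtype)
        # Symmetric padding of h - 1 along H yields h - 1 extra trailing rows...
        out = F.conv2d(view, weight, padding=(h - 1, 0))
        # ...which the pass removes with a slice before restoring the shape.
        return out[:, :, :h, :].reshape(shape)

    x = torch.rand(2, 3, 4)
    for d in range(x.dim()):
        assert torch.allclose(cumsum_via_conv(x, d), torch.cumsum(x, d))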

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions

@@ -169,6 +169,7 @@ def is_node_supported(
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.ceil.default,
             exir_ops.edge.aten.clamp.default,
+            exir_ops.edge.aten.cumsum.default,
             exir_ops.edge.aten.bmm.default,
             exir_ops.edge.aten.permute_copy.default,
             exir_ops.edge.aten.hardsigmoid.default,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions

@@ -290,6 +290,7 @@ def _match_pattern(
     torch.ops.aten.asinh.default,
     torch.ops.aten.cosh.default,
     torch.ops.aten.acos.default,
+    torch.ops.aten.cumsum.default,
 ]

 _one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_cumsum.py (new file)

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineINT,
    EthosU85PipelineINT,
    TosaPipelineFP,
    TosaPipelineINT,
    VgfPipeline,
)

input_t1 = Tuple[torch.Tensor, int]
aten_op = "torch.ops.aten.cumsum.default"

"""
Tests the aten.cumsum operator by decomposing it into a convolution and
verifying results across various dims and pipelines.
"""


class CumsumModule(torch.nn.Module):
    test_parameters = {
        "1d_dim0": lambda: (torch.rand(10), 0),
        "1d_dim_neg1": lambda: (torch.rand(10), -1),
        "2d_dim1": lambda: (torch.rand(5, 6), 1),
        "3d_dim2": lambda: (torch.rand(2, 3, 4), 2),
        "3d_dim0": lambda: (torch.rand(2, 3, 4), 0),
        "4d_dim3": lambda: (torch.rand(1, 2, 3, 4), 3),
        "4d_dim1": lambda: (torch.rand(1, 2, 3, 4), 1),
    }

    def forward(self, x: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.cumsum(x, dim)


@common.parametrize("test_data", CumsumModule.test_parameters)
def test_cumsum_tosa_FP(test_data: input_t1):
    module = CumsumModule()
    args = test_data()
    pipeline = TosaPipelineFP[input_t1](
        module,
        args,
        aten_op,
        exir_op=[],
    )
    pipeline.run()


@common.parametrize("test_data", CumsumModule.test_parameters)
def test_cumsum_tosa_INT(test_data: input_t1):
    module = CumsumModule()
    args = test_data()
    pipeline = TosaPipelineINT[input_t1](
        module,
        args,
        aten_op,
        exir_op=[],
    )
    pipeline.run()


@common.parametrize("test_data", CumsumModule.test_parameters)
@common.SkipIfNoModelConverter
def test_cumsum_vgf_FP(test_data: input_t1):
    module = CumsumModule()
    args = test_data()
    pipeline = VgfPipeline[input_t1](
        module,
        args,
        aten_op,
        tosa_version="TOSA-1.0+FP",
    )
    pipeline.run()


@common.parametrize("test_data", CumsumModule.test_parameters)
@common.SkipIfNoModelConverter
def test_cumsum_vgf_INT(test_data: input_t1):
    module = CumsumModule()
    args = test_data()
    pipeline = VgfPipeline[input_t1](
        module,
        args,
        aten_op,
        tosa_version="TOSA-1.0+INT",
    )
    pipeline.run()


@common.parametrize("test_data", CumsumModule.test_parameters)
@common.XfailIfNoCorstone300
def test_cumsum_u55_INT(test_data: input_t1):
    module = CumsumModule()
    args = test_data()
    pipeline = EthosU55PipelineINT[input_t1](
        module,
        args,
        aten_ops=aten_op,
        exir_ops=[],
    )
    pipeline.run()


@common.parametrize("test_data", CumsumModule.test_parameters)
@common.XfailIfNoCorstone320
def test_cumsum_u85_INT(test_data: input_t1):
    module = CumsumModule()
    args = test_data()
    pipeline = EthosU85PipelineINT[input_t1](
        module,
        args,
        aten_ops=aten_op,
        exir_ops=[],
    )
    pipeline.run()
