Skip to content

Commit 36ab971

Browse files
authored
Arm backend: support torch.sum.default (#15380)
When sum is called without arguments, it is traced as sum.default rather than sum.dim_IntList, which 1) is lowered to edge with dims=[] instead of None, and 2) is not annotated in the quantizer. This commit fixes those issues and adds tests. Signed-off-by: Erik Lundell <[email protected]>
1 parent aa21a49 commit 36ab971

File tree

4 files changed

+32
-4
lines changed

4 files changed

+32
-4
lines changed

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def call_operator(self, op, args, kwargs, meta):
6868
case _:
6969
raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")
7070

71-
# If dims is None, sum over all dimensions
72-
if dims is None:
71+
# If dims evaluates to False (None or []), sum over all dimensions
72+
if not dims:
7373
shape = input_node.data.size()
7474
dims = list(range(len(shape)))
7575

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def _match_pattern(
284284
torch.ops.aten.sin.default,
285285
torch.ops.aten.tanh.default,
286286
torch.ops.aten.sum.dim_IntList,
287+
torch.ops.aten.sum.default,
287288
torch.ops.aten.hardsigmoid.default,
288289
torch.ops.aten.hardswish.default,
289290
torch.ops.aten.hardswish_.default,

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"native_group_norm.default",
2323
"silu.default",
2424
"sdpa.default",
25+
"sum.default",
2526
"unbind.int",
2627
"unflatten.int",
2728
"_native_batch_norm_legit_no_training.default",

backends/arm/test/ops/test_sum.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Tuple
6+
from typing import Callable, Tuple
77

88
import torch
99
from executorch.backends.arm.test import common
@@ -61,7 +61,6 @@ def test_sum_dim_intlist_tosa_INT(test_data: input_t1):
6161
aten_op,
6262
exir_op=[],
6363
)
64-
pipeline.dump_artifact("export")
6564
pipeline.run()
6665

6766

@@ -133,3 +132,30 @@ def test_view_u55_INT_failure_set(test_data: Tuple):
133132
)
134133
pipeline.pop_stage("check_count.exir")
135134
pipeline.run()
135+
136+
137+
input_t2 = tuple[torch.Tensor]
138+
139+
140+
class SumDefault(torch.nn.Module):
141+
test_parameters = {
142+
"rank1": lambda: (torch.rand(10),),
143+
"rank2": lambda: (torch.rand(10, 1, 10),),
144+
"rank4": lambda: (torch.rand(1, 1, 5, 8),),
145+
}
146+
aten_op = "torch.ops.aten.sum.default"
147+
148+
def forward(self, x: torch.Tensor):
149+
return x.sum()
150+
151+
152+
@common.parametrize("test_data", SumDefault.test_parameters)
153+
def test_sum_tosa_FP(test_data: Callable[[], input_t2]):
154+
pipeline = TosaPipelineFP[input_t2](SumDefault(), test_data(), SumDefault.aten_op)
155+
pipeline.run()
156+
157+
158+
@common.parametrize("test_data", SumDefault.test_parameters)
159+
def test_sum_tosa_INT(test_data: Callable[[], input_t2]):
160+
pipeline = TosaPipelineINT[input_t1](SumDefault(), test_data(), SumDefault.aten_op)
161+
pipeline.run()

0 commit comments

Comments
 (0)