From fd928a6d5b653513aabe7b3220af499479d9c15d Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 24 Oct 2025 10:48:40 +0200 Subject: [PATCH] Arm backend: support torch.sum() When sum is called without arguments, it is traced as sum.default rather than sum.dim_IntList, which 1) is lowered to edge with dims=[] instead of None, and 2) is not annotated in the quantizer. This commit fixes those issues and adds tests. Signed-off-by: Erik Lundell Change-Id: Ic9ae84c62a713f136de7bff2594bacbdfc759995 --- backends/arm/_passes/decompose_sum_pass.py | 4 +-- .../arm/quantizer/quantization_annotator.py | 1 + backends/arm/scripts/parse_test_names.py | 1 + backends/arm/test/ops/test_sum.py | 30 +++++++++++++++++-- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 589dcfcefa7..d96616a6373 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -68,8 +68,8 @@ def call_operator(self, op, args, kwargs, meta): case _: raise ValueError(f"Invalid number of arguments ({len(args)}) provided.") - # If dims is None, sum over all dimensions - if dims is None: + # If dims evaluates to False (None or []), sum over all dimensions + if not dims: shape = input_node.data.size() dims = list(range(len(shape))) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index ee7003aacb8..99c18953efb 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -284,6 +284,7 @@ def _match_pattern( torch.ops.aten.sin.default, torch.ops.aten.tanh.default, torch.ops.aten.sum.dim_IntList, + torch.ops.aten.sum.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default, diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 
56a2a9c6890..a663ba2e8b7 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -22,6 +22,7 @@ "native_group_norm.default", "silu.default", "sdpa.default", + "sum.default", "unbind.int", "unflatten.int", "_native_batch_norm_legit_no_training.default", diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index f0af9a022e8..050a50e7251 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Callable, Tuple import torch from executorch.backends.arm.test import common @@ -61,7 +61,6 @@ def test_sum_dim_intlist_tosa_INT(test_data: input_t1): aten_op, exir_op=[], ) - pipeline.dump_artifact("export") pipeline.run() @@ -133,3 +132,30 @@ def test_view_u55_INT_failure_set(test_data: Tuple): ) pipeline.pop_stage("check_count.exir") pipeline.run() + + +input_t2 = tuple[torch.Tensor] + + +class SumDefault(torch.nn.Module): + test_parameters = { + "rank1": lambda: (torch.rand(10),), + "rank2": lambda: (torch.rand(10, 1, 10),), + "rank4": lambda: (torch.rand(1, 1, 5, 8),), + } + aten_op = "torch.ops.aten.sum.default" + + def forward(self, x: torch.Tensor): + return x.sum() + + +@common.parametrize("test_data", SumDefault.test_parameters) +def test_sum_tosa_FP(test_data: Callable[[], input_t2]): + pipeline = TosaPipelineFP[input_t2](SumDefault(), test_data(), SumDefault.aten_op) + pipeline.run() + + +@common.parametrize("test_data", SumDefault.test_parameters) +def test_sum_tosa_INT(test_data: Callable[[], input_t2]): + pipeline = TosaPipelineINT[input_t2](SumDefault(), test_data(), SumDefault.aten_op) + pipeline.run()