From d28d93fd3877179a5a3bc442e8afe92b8a71b8f4 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 7 Oct 2025 16:00:19 +0200 Subject: [PATCH] Arm backend: Decompose sub/add with alpha!=1 This was previously not supported, causing crashes in quantization, and incorrect output in floating point. Signed-off-by: Erik Lundell Change-Id: Ib176007ec7c0be311fac4a3fb738200cac201bbc --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 3 + .../_passes/decompose_add_sub_alpha_pass.py | 94 +++++++++++++++++++ backends/arm/test/ops/test_add.py | 2 +- backends/arm/test/ops/test_sub.py | 26 +++++ 5 files changed, 125 insertions(+), 1 deletion(-) create mode 100644 backends/arm/_passes/decompose_add_sub_alpha_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1374ed8a3d3..b1337c38a58 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -27,6 +27,7 @@ from .convert_to_clamp import ConvertToClampPass # noqa from .decompose_acosh_pass import DecomposeAcoshPass # noqa from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa +from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa from .decompose_addmm_pass import DecomposeAddmmPass # noqa from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa from .decompose_asinh_pass import DecomposeAsinhPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index ef6d6e6810a..325f667f0ac 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -36,6 +36,7 @@ DecomposeAcoshPass, DecomposeAdaptiveAvgPool2dPass, DecomposeAddmmPass, + DecomposeAddSubAlphaPass, DecomposeAsinAndAcosPass, DecomposeAsinhPass, DecomposeAtanhPass, @@ -262,6 +263,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: ) self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeDivPass()) + self.add_pass(DecomposeAddSubAlphaPass()) self.add_pass(DecomposeSoftmaxPass()) self.add_pass(DecomposeGeluPass()) self.add_pass(ConvertFullLikeToFullPass()) @@ -334,6 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeAddmmPass()) self.add_pass(DecomposeDivTensorModePass()) + self.add_pass(DecomposeAddSubAlphaPass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py new file mode 100644 index 00000000000..c0ed1bae09b --- /dev/null +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -0,0 +1,94 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import numbers +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +_ADD_OPS = ( + exir_ops.edge.aten.add.Tensor, + torch.ops.aten.add.Tensor, +) + +_SUB_OPS = ( + exir_ops.edge.aten.sub.Tensor, + torch.ops.aten.sub.Tensor, +) + + +def _get_ops(op): + if op in _ADD_OPS: + if op is exir_ops.edge.aten.add.Tensor: + return ( + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.add.Tensor, + ) + return ( + torch.ops.aten.mul.Tensor, + torch.ops.aten.full.default, + torch.ops.aten.add.Tensor, + ) + if op in _SUB_OPS: + if op is exir_ops.edge.aten.sub.Tensor: + return ( + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.sub.Tensor, + ) + return ( + torch.ops.aten.mul.Tensor, + torch.ops.aten.full.default, + torch.ops.aten.sub.Tensor, + ) + raise RuntimeError(f"Unsupported operator {op}") + + +def _should_decompose(alpha) -> bool: + if isinstance(alpha, numbers.Number): + return alpha != 1 + return False + + +class DecomposeAddSubAlphaPass(ArmPass): + """Rewrite add/sub with alpha into a mul followed by add/sub.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): + if op not in _ADD_OPS + _SUB_OPS: + return super().call_operator(op, args, kwargs, meta, updated) + + alpha = kwargs.get("alpha", 1) + if not _should_decompose(alpha): + return super().call_operator(op, args, kwargs, meta, updated) + + mul_op, full_op, binary_op = _get_ops(op) + lhs, rhs = args + + alpha_full = super().call_operator( + full_op, ((1,), float(alpha)), {}, meta, updated=True + ) + scaled_rhs = super().call_operator( + mul_op, + (rhs, alpha_full), + {}, + meta, + updated=True, + ) + return super().call_operator( + binary_op, + (lhs, scaled_rhs), + {}, + meta, + updated=True, + ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 9b3f98763c6..bcab40116d8 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -78,7 +78,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): class Add3(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): - return x + y + return torch.add(x, y, alpha=1.5) test_data: list[input_t2] = { "3d_randn_diff_rank": lambda: (torch.randn(1, 4, 5), torch.randn(4, 1)), diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index 9c02243f30f..68b6ad5fb93 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -79,6 +79,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return x - y +class SubAlpha(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.sub(x, y, alpha=5) + + class SubTan(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): @@ -115,6 +120,18 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]): pipeline.run() +@common.parametrize("test_data", sub_tan_test_data) +def test_sub_tensor_tosa_FP_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction with alpha (TOSA FP)""" + pipeline = TosaPipelineFP[input_t2]( + SubAlpha(), + test_data(), + aten_op, + exir_op, + ) + pipeline.run() + + @common.parametrize("test_data", sub_test_data) def test_sub_tensor_tosa_INT(test_data): """Test Subtraction (TOSA INT)""" @@ -138,6 +155,15 @@ def test_sub_tensor_tosa_INT_3(test_data: Tuple[torch.Tensor, torch.Tensor]): pipeline.run() +@common.parametrize("test_data", sub_tan_test_data) +def test_sub_tensor_tosa_INT_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction with alpha (TOSA INT)""" + pipeline = TosaPipelineINT[input_t2]( + SubAlpha(), test_data(), aten_op, exir_op, qtol=0 + ) + pipeline.run() + + @common.parametrize("test_data", sub_test_data) @common.XfailIfNoCorstone300 def test_sub_tensor_u55_INT(test_data):