1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -27,6 +27,7 @@
from .convert_to_clamp import ConvertToClampPass # noqa
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -36,6 +36,7 @@
DecomposeAcoshPass,
DecomposeAdaptiveAvgPool2dPass,
DecomposeAddmmPass,
DecomposeAddSubAlphaPass,
DecomposeAsinAndAcosPass,
DecomposeAsinhPass,
DecomposeAtanhPass,
@@ -262,6 +263,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
)
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeAddSubAlphaPass())
self.add_pass(DecomposeSoftmaxPass())
self.add_pass(DecomposeGeluPass())
self.add_pass(ConvertFullLikeToFullPass())
@@ -334,6 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeSignPass())
self.add_pass(DecomposeAddmmPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeAddSubAlphaPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
94 changes: 94 additions & 0 deletions backends/arm/_passes/decompose_add_sub_alpha_pass.py
@@ -0,0 +1,94 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import numbers
from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


_ADD_OPS = (
exir_ops.edge.aten.add.Tensor,
torch.ops.aten.add.Tensor,
)

_SUB_OPS = (
exir_ops.edge.aten.sub.Tensor,
torch.ops.aten.sub.Tensor,
)


def _get_ops(op):
if op in _ADD_OPS:
if op is exir_ops.edge.aten.add.Tensor:
return (
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.add.Tensor,
)
return (
torch.ops.aten.mul.Tensor,
torch.ops.aten.full.default,
torch.ops.aten.add.Tensor,
)
if op in _SUB_OPS:
if op is exir_ops.edge.aten.sub.Tensor:
return (
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.sub.Tensor,
)
return (
torch.ops.aten.mul.Tensor,
torch.ops.aten.full.default,
torch.ops.aten.sub.Tensor,
)
raise RuntimeError(f"Unsupported operator {op}")


def _should_decompose(alpha) -> bool:
if isinstance(alpha, numbers.Number):
return alpha != 1
return False


class DecomposeAddSubAlphaPass(ArmPass):
"""Rewrite add/sub with alpha into a mul followed by add/sub."""

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
if op not in _ADD_OPS + _SUB_OPS:
return super().call_operator(op, args, kwargs, meta, updated)

alpha = kwargs.get("alpha", 1)
if not _should_decompose(alpha):
return super().call_operator(op, args, kwargs, meta, updated)

mul_op, full_op, binary_op = _get_ops(op)
lhs, rhs = args

alpha_full = super().call_operator(
full_op, ((1,), float(alpha)), {}, meta, updated=True
)
scaled_rhs = super().call_operator(
mul_op,
(rhs, alpha_full),
{},
meta,
updated=True,
)
return super().call_operator(
binary_op,
(lhs, scaled_rhs),
{},
meta,
updated=True,
)
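
For context, the rewrite performed by the new pass is semantically equivalent to the sketch below. This is a minimal illustration, not code from the PR: the tensors x and y, the alpha value, and the allclose check are made up for the example.

import torch

x = torch.randn(2, 3)
y = torch.randn(2, 3)
alpha = 1.5

# torch.add(x, y, alpha=alpha) computes x + alpha * y. The pass
# materialises alpha as a rank-1 full tensor, scales the right-hand
# operand with a mul, then emits a plain add (or sub) on the result.
alpha_tensor = torch.full((1,), float(alpha))
scaled_rhs = y * alpha_tensor
decomposed = x + scaled_rhs

assert torch.allclose(decomposed, torch.add(x, y, alpha=alpha))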
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_add.py
@@ -78,7 +78,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):

class Add3(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return x + y
return torch.add(x, y, alpha=1.5)

test_data: list[input_t2] = {
"3d_randn_diff_rank": lambda: (torch.randn(1, 4, 5), torch.randn(4, 1)),
26 changes: 26 additions & 0 deletions backends/arm/test/ops/test_sub.py
@@ -79,6 +79,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
return x - y


class SubAlpha(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.sub(x, y, alpha=5)


class SubTan(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor):
@@ -115,6 +120,18 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
pipeline.run()


@common.parametrize("test_data", sub_tan_test_data)
def test_sub_tensor_tosa_FP_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
"""Test Two-Operand Subtraction with alpha (TOSA FP)"""
pipeline = TosaPipelineFP[input_t2](
SubAlpha(),
test_data(),
aten_op,
exir_op,
)
pipeline.run()


@common.parametrize("test_data", sub_test_data)
def test_sub_tensor_tosa_INT(test_data):
"""Test Subtraction (TOSA INT)"""
@@ -138,6 +155,15 @@ def test_sub_tensor_tosa_INT_3(test_data: Tuple[torch.Tensor, torch.Tensor]):
pipeline.run()


@common.parametrize("test_data", sub_tan_test_data)
def test_sub_tensor_tosa_INT_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
"""Test Two-Operand Subtraction with alpha (TOSA INT)"""
pipeline = TosaPipelineINT[input_t2](
SubAlpha(), test_data(), aten_op, exir_op, qtol=0
)
pipeline.run()


@common.parametrize("test_data", sub_test_data)
@common.XfailIfNoCorstone300
def test_sub_tensor_u55_INT(test_data):