| 
 | 1 | +# Copyright 2025 Arm Limited and/or its affiliates.  | 
 | 2 | +#  | 
 | 3 | +# This source code is licensed under the BSD-style license found in the  | 
 | 4 | +# LICENSE file in the root directory of this source tree.  | 
 | 5 | + | 
 | 6 | +from __future__ import annotations  | 
 | 7 | + | 
 | 8 | +import numbers  | 
 | 9 | +from typing import Set, Type  | 
 | 10 | + | 
 | 11 | +import torch  | 
 | 12 | +from executorch.backends.arm._passes import ArmPass  | 
 | 13 | +from executorch.exir.dialects._ops import ops as exir_ops  | 
 | 14 | +from executorch.exir.pass_base import ExportPass  | 
 | 15 | + | 
 | 16 | + | 
 | 17 | +_ADD_OPS = (  | 
 | 18 | +    exir_ops.edge.aten.add.Tensor,  | 
 | 19 | +    torch.ops.aten.add.Tensor,  | 
 | 20 | +)  | 
 | 21 | + | 
 | 22 | +_SUB_OPS = (  | 
 | 23 | +    exir_ops.edge.aten.sub.Tensor,  | 
 | 24 | +    torch.ops.aten.sub.Tensor,  | 
 | 25 | +)  | 
 | 26 | + | 
 | 27 | + | 
 | 28 | +def _get_ops(op):  | 
 | 29 | +    if op in _ADD_OPS:  | 
 | 30 | +        if op is exir_ops.edge.aten.add.Tensor:  | 
 | 31 | +            return (  | 
 | 32 | +                exir_ops.edge.aten.mul.Tensor,  | 
 | 33 | +                exir_ops.edge.aten.full.default,  | 
 | 34 | +                exir_ops.edge.aten.add.Tensor,  | 
 | 35 | +            )  | 
 | 36 | +        return (  | 
 | 37 | +            torch.ops.aten.mul.Tensor,  | 
 | 38 | +            torch.ops.aten.full.default,  | 
 | 39 | +            torch.ops.aten.add.Tensor,  | 
 | 40 | +        )  | 
 | 41 | +    if op in _SUB_OPS:  | 
 | 42 | +        if op is exir_ops.edge.aten.sub.Tensor:  | 
 | 43 | +            return (  | 
 | 44 | +                exir_ops.edge.aten.mul.Tensor,  | 
 | 45 | +                exir_ops.edge.aten.full.default,  | 
 | 46 | +                exir_ops.edge.aten.sub.Tensor,  | 
 | 47 | +            )  | 
 | 48 | +        return (  | 
 | 49 | +            torch.ops.aten.mul.Tensor,  | 
 | 50 | +            torch.ops.aten.full.default,  | 
 | 51 | +            torch.ops.aten.sub.Tensor,  | 
 | 52 | +        )  | 
 | 53 | +    raise RuntimeError(f"Unsupported operator {op}")  | 
 | 54 | + | 
 | 55 | + | 
 | 56 | +def _should_decompose(alpha) -> bool:  | 
 | 57 | +    if isinstance(alpha, numbers.Number):  | 
 | 58 | +        return alpha != 1  | 
 | 59 | +    return False  | 
 | 60 | + | 
 | 61 | + | 
 | 62 | +class DecomposeAddSubAlphaPass(ArmPass):  | 
 | 63 | +    """Rewrite add/sub with alpha into a mul followed by add/sub."""  | 
 | 64 | + | 
 | 65 | +    _passes_required_after: Set[Type[ExportPass]] = set()  | 
 | 66 | + | 
 | 67 | +    def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):  | 
 | 68 | +        if op not in _ADD_OPS + _SUB_OPS:  | 
 | 69 | +            return super().call_operator(op, args, kwargs, meta, updated)  | 
 | 70 | + | 
 | 71 | +        alpha = kwargs.get("alpha", 1)  | 
 | 72 | +        if not _should_decompose(alpha):  | 
 | 73 | +            return super().call_operator(op, args, kwargs, meta, updated)  | 
 | 74 | + | 
 | 75 | +        mul_op, full_op, binary_op = _get_ops(op)  | 
 | 76 | +        lhs, rhs = args  | 
 | 77 | + | 
 | 78 | +        alpha_full = super().call_operator(  | 
 | 79 | +            full_op, ((1,), float(alpha)), {}, meta, updated=True  | 
 | 80 | +        )  | 
 | 81 | +        scaled_rhs = super().call_operator(  | 
 | 82 | +            mul_op,  | 
 | 83 | +            (rhs, alpha_full),  | 
 | 84 | +            {},  | 
 | 85 | +            meta,  | 
 | 86 | +            updated=True,  | 
 | 87 | +        )  | 
 | 88 | +        return super().call_operator(  | 
 | 89 | +            binary_op,  | 
 | 90 | +            (lhs, scaled_rhs),  | 
 | 91 | +            {},  | 
 | 92 | +            meta,  | 
 | 93 | +            updated=True,  | 
 | 94 | +        )  | 
0 commit comments