Skip to content

Arm Backend: Add support for expm1.default #13274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
DecomposeCosineSimilarityPass,
DecomposeDivPass,
DecomposeEmbeddingPass,
DecomposeExpm1Pass,
DecomposeGeluPass,
DecomposeGluPass,
DecomposeGroupedConv,
Expand Down Expand Up @@ -164,6 +165,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
return self._transform(exported_program.graph_module)

def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeExpm1Pass())
self.add_pass(DecomposeMaskedFill())
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeAcoshPass())
Expand Down
135 changes: 135 additions & 0 deletions backends/arm/_passes/decompose_expm1_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case


def _get_expm1_decomposition(op) -> tuple:
"""
Returns the decomposition of the given aten.expm1 operation into
its equivalent TOSA-supported operations

This handles both edge dialect ops and core PyTorch ops. The decomposition strategy
is:
expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1))

where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24)

Returns:
A tuple (op_pow, op_div, op_add, op_exp, op_sub, op_ge, op_where, op_le, op_and)
corresponding to the appropriate operator overloads for the input op.

Raises:
RuntimeError: If the provided operator is not a supported elu variant.
"""
if op in edge_expm1_ops:
return (
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.div.Scalar,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.ge.Scalar,
exir_ops.edge.aten.where.self,
exir_ops.edge.aten.le.Scalar,
exir_ops.edge.aten.logical_and.default,
)

raise RuntimeError(f"Can't get expm1 decomposition for op {op}")


class DecomposeExpm1Pass(ArmPass):
"""
A transformation pass that decomposes unsupported 'aten.expm1' operations
into a combination of supported TOSA-equivalent operations.

Since TOSA does not provide a native expm1 operator, this pass rewrites:
expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1))
where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24)

Supported input ops:
- exir_ops.edge.aten.expm1.default(x)

These are replaced with:
- exir_ops.edge.aten.pow.Tensor_Scalar,
- exir_ops.edge.aten.div.Scalar,
- exir_ops.edge.aten.add.Tensor,
- exir_ops.edge.aten.exp.default,
- exir_ops.edge.aten.sub.Scalar,
- exir_ops.edge.aten.ge.Scalar,
- exir_ops.edge.aten.where.self,
- exir_ops.edge.aten.le.Scalar,
- exir_ops.edge.aten.logical_and.default
"""

def call_operator(self, op, args, kwargs, meta):
if op not in edge_expm1_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)

(
op_pow,
op_div,
op_add,
op_exp,
op_sub,
op_ge,
op_where,
op_le,
op_and,
) = _get_expm1_decomposition(op)

input = args[0]

cutlo = -0.35
cuthi = 0.35

taylor_term_2_numerator = super().call_operator(
op_pow, (input, 2), {}, meta, updated=False
)
taylor_term_3_numerator = super().call_operator(
op_pow, (input, 3), {}, meta, updated=False
)
taylor_term_4_numerator = super().call_operator(
op_pow, (input, 4), {}, meta, updated=False
)

taylor_term_2 = super().call_operator(
op_div, (taylor_term_2_numerator, 2), {}, meta, updated=False
)
taylor_term_3 = super().call_operator(
op_div, (taylor_term_3_numerator, 6), {}, meta, updated=False
)
taylor_term_4 = super().call_operator(
op_div, (taylor_term_4_numerator, 24), {}, meta, updated=False
)

add_terms_1_2 = super().call_operator(
op_add, (input, taylor_term_2), {}, meta, updated=False
)
add_term_3 = super().call_operator(
op_add, (add_terms_1_2, taylor_term_3), {}, meta, updated=False
)
taylor_expansion = super().call_operator(
op_add, (add_term_3, taylor_term_4), {}, meta, updated=False
)

decomp_exp = super().call_operator(op_exp, (input,), {}, meta, updated=False)
decomp_sub = super().call_operator(
op_sub, (decomp_exp, 1.0), {}, meta, updated=False
)

ge = super().call_operator(op_ge, (input, cutlo), {}, meta, updated=False)
le = super().call_operator(op_le, (input, cuthi), {}, meta, updated=False)

cond_and = super().call_operator(op_and, (ge, le), {}, meta, updated=False)
where = super().call_operator(
op_where, (cond_and, taylor_expansion, decomp_sub), {}, meta, updated=True
)

return where
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TableOps:
exir_ops.edge.aten.ceil.default: torch.ceil,
exir_ops.edge.aten.erf.default: torch.erf,
exir_ops.edge.aten.exp.default: torch.exp,
exir_ops.edge.aten.expm1.default: torch.expm1,
exir_ops.edge.aten.floor.default: torch.floor,
exir_ops.edge.aten.log.default: torch.log,
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def is_node_supported(
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.erf.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.expm1.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def _match_pattern(
torch.ops.aten.ceil.default,
torch.ops.aten.erf.default,
torch.ops.aten.exp.default,
torch.ops.aten.expm1.default,
torch.ops.aten.floor.default,
torch.ops.aten.log.default,
torch.ops.aten.reciprocal.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/scripts/parse_test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CUSTOM_EDGE_OPS = [
"linspace.default",
"eye.default",
"expm1.default",
"vector_norm.default",
"hardsigmoid.default",
"hardswish.default",
Expand Down
113 changes: 113 additions & 0 deletions backends/arm/test/ops/test_expm1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

aten_op = "torch.ops.aten.expm1.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_expm1_default"

input_t1 = Tuple[torch.Tensor]

test_data_suite = {
"zeroes": torch.zeros(1, 10, 10, 10),
"ones": torch.ones(10, 2, 3),
"rand": torch.rand(10, 10) - 0.5,
"near_zero": torch.randn(100) * 0.01,
"taylor_small": torch.empty(5).uniform_(
-0.35, 0.35
), # test cases for taylor series expansion
"randn_large_pos": torch.randn(10) + 10,
"randn_large_neg": torch.randn(10) - 10,
"ramp": torch.arange(-16, 16, 0.2),
}


class Expm1(torch.nn.Module):

def forward(self, x: torch.Tensor):
return torch.expm1(x)


@common.parametrize("test_data", test_data_suite)
def test_expm1_tosa_FP(test_data: Tuple):
pipeline = TosaPipelineFP[input_t1](
Expm1(),
(test_data,),
aten_op=aten_op,
exir_op=exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_expm1_tosa_INT(test_data: Tuple):
pipeline = TosaPipelineINT[input_t1](
Expm1(),
(test_data,),
aten_op=aten_op,
exir_op=exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_expm1_u55_INT(test_data: Tuple):
pipeline = EthosU55PipelineINT[input_t1](
Expm1(),
(test_data,),
aten_ops=aten_op,
exir_ops=exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_expm1_u85_INT(test_data: Tuple):
pipeline = EthosU85PipelineINT[input_t1](
Expm1(),
(test_data,),
aten_ops=aten_op,
exir_ops=exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_expm1_vgf_FP(test_data: Tuple):
pipeline = VgfPipeline[input_t1](
Expm1(),
(test_data,),
aten_op,
exir_op,
tosa_version="TOSA-1.0+FP",
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_expm1_vgf_INT(test_data: Tuple):
pipeline = VgfPipeline[input_t1](
Expm1(),
(test_data,),
aten_op,
exir_op,
tosa_version="TOSA-1.0+INT",
)
pipeline.run()
Loading