Skip to content

Commit 846bc1e

Browse files
authored
Arm Backend: Add support for expm1.default (#13274)
Decompose expm1 into other operators or use the Taylor series expansion when input values are close to 0. Signed-off-by: Agrima Khare <[email protected]>
1 parent 17047ea commit 846bc1e

File tree

8 files changed

+255
-0
lines changed

8 files changed

+255
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
3636
from .decompose_div_pass import DecomposeDivPass # noqa
3737
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
38+
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
3839
from .decompose_gelu_pass import DecomposeGeluPass # noqa
3940
from .decompose_glu_pass import DecomposeGluPass # noqa
4041
from .decompose_grouped_conv import DecomposeGroupedConv # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
DecomposeCosineSimilarityPass,
4141
DecomposeDivPass,
4242
DecomposeEmbeddingPass,
43+
DecomposeExpm1Pass,
4344
DecomposeGeluPass,
4445
DecomposeGluPass,
4546
DecomposeGroupedConv,
@@ -164,6 +165,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
164165
return self._transform(exported_program.graph_module)
165166

166167
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
168+
self.add_pass(DecomposeExpm1Pass())
167169
self.add_pass(DecomposeMaskedFill())
168170
self.add_pass(DecomposeRoundPass())
169171
self.add_pass(DecomposeAcoshPass())
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
9+
10+
# Edge-dialect expm1 overloads that the decomposition in this file applies to.
edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case
11+
12+
13+
def _get_expm1_decomposition(op) -> tuple:
    """
    Return the operator overloads used to decompose the given expm1
    operation into TOSA-supported operations.

    Only the edge dialect ops listed in ``edge_expm1_ops`` are handled. The
    decomposition strategy is:
        expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1))

    where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24)

    Args:
        op: The operator overload being decomposed.

    Returns:
        A tuple (op_pow, op_div, op_add, op_exp, op_sub, op_ge, op_where, op_le, op_and)
        corresponding to the appropriate operator overloads for the input op.
        Callers unpack the tuple positionally, so the order is significant.

    Raises:
        RuntimeError: If the provided operator is not a supported expm1 variant.
    """
    if op in edge_expm1_ops:
        return (
            exir_ops.edge.aten.pow.Tensor_Scalar,
            exir_ops.edge.aten.div.Scalar,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.sub.Scalar,
            exir_ops.edge.aten.ge.Scalar,
            exir_ops.edge.aten.where.self,
            exir_ops.edge.aten.le.Scalar,
            exir_ops.edge.aten.logical_and.default,
        )

    raise RuntimeError(f"Can't get expm1 decomposition for op {op}")
45+
46+
47+
class DecomposeExpm1Pass(ArmPass):
    """
    Pass that rewrites 'aten.expm1' operations, which this backend treats as
    unsupported, into a combination of supported TOSA-equivalent operations.

    Since TOSA does not provide a native expm1 operator, this pass rewrites:
        expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1))
    where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24)

    Supported input ops:
        - exir_ops.edge.aten.expm1.default(x)

    These are replaced with:
        - exir_ops.edge.aten.pow.Tensor_Scalar,
        - exir_ops.edge.aten.div.Scalar,
        - exir_ops.edge.aten.add.Tensor,
        - exir_ops.edge.aten.exp.default,
        - exir_ops.edge.aten.sub.Scalar,
        - exir_ops.edge.aten.ge.Scalar,
        - exir_ops.edge.aten.where.self,
        - exir_ops.edge.aten.le.Scalar,
        - exir_ops.edge.aten.logical_and.default
    """

    def call_operator(self, op, args, kwargs, meta):
        # Every operator other than expm1 is forwarded to the parent untouched.
        if op not in edge_expm1_ops:
            return super().call_operator(op, args, kwargs, meta, updated=False)

        (
            pow_op,
            div_op,
            add_op,
            exp_op,
            sub_op,
            ge_op,
            where_op,
            le_op,
            and_op,
        ) = _get_expm1_decomposition(op)

        # Zero-argument super() is not usable inside the nested helper, so the
        # bound parent method is captured here.
        parent_call = super().call_operator

        def emit(target, *operands):
            # Emit one intermediate node through the parent pass, reusing the
            # metadata of the expm1 node being replaced.
            return parent_call(target, operands, {}, meta, updated=False)

        x = args[0]

        # The Taylor branch is selected for lower_cutoff <= x <= upper_cutoff.
        lower_cutoff = -0.35
        upper_cutoff = 0.35

        # Numerators x^2, x^3, x^4 are emitted first, then each is divided by
        # the matching factorial (2!, 3!, 4!).
        numerators = [emit(pow_op, x, exponent) for exponent in (2, 3, 4)]
        higher_order_terms = [
            emit(div_op, numerator, factorial)
            for numerator, factorial in zip(numerators, (2, 6, 24))
        ]

        # Accumulate x + x^2/2 + x^3/6 + x^4/24, left to right.
        series = x
        for term in higher_order_terms:
            series = emit(add_op, series, term)

        # Direct form exp(x) - 1, used outside the cutoff interval.
        exp_of_x = emit(exp_op, x)
        exp_minus_one = emit(sub_op, exp_of_x, 1.0)

        # Mask that is true where x lies inside [lower_cutoff, upper_cutoff].
        above_lower = emit(ge_op, x, lower_cutoff)
        below_upper = emit(le_op, x, upper_cutoff)
        in_taylor_range = emit(and_op, above_lower, below_upper)

        # Only the final node is created with updated=True; it is issued
        # directly from this method rather than through the helper.
        # NOTE(review): presumably this flags the result as rewritten by the
        # pass - confirm against ArmPass.call_operator.
        return super().call_operator(
            where_op,
            (in_taylor_range, series, exp_minus_one),
            {},
            meta,
            updated=True,
        )

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TableOps:
4343
exir_ops.edge.aten.ceil.default: torch.ceil,
4444
exir_ops.edge.aten.erf.default: torch.erf,
4545
exir_ops.edge.aten.exp.default: torch.exp,
46+
exir_ops.edge.aten.expm1.default: torch.expm1,
4647
exir_ops.edge.aten.floor.default: torch.floor,
4748
exir_ops.edge.aten.log.default: torch.log,
4849
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def is_node_supported(
179179
exir_ops.edge.aten.eq.Scalar,
180180
exir_ops.edge.aten.erf.default,
181181
exir_ops.edge.aten.exp.default,
182+
exir_ops.edge.aten.expm1.default,
182183
exir_ops.edge.aten.log.default,
183184
exir_ops.edge.aten.linear.default,
184185
exir_ops.edge.aten.split_with_sizes_copy.default,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def _match_pattern(
265265
torch.ops.aten.ceil.default,
266266
torch.ops.aten.erf.default,
267267
torch.ops.aten.exp.default,
268+
torch.ops.aten.expm1.default,
268269
torch.ops.aten.floor.default,
269270
torch.ops.aten.log.default,
270271
torch.ops.aten.reciprocal.default,

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
CUSTOM_EDGE_OPS = [
99
"linspace.default",
1010
"eye.default",
11+
"expm1.default",
1112
"vector_norm.default",
1213
"hardsigmoid.default",
1314
"hardswish.default",

backends/arm/test/ops/test_expm1.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
16+
VgfPipeline,
17+
)
18+
19+
# Operator names handed to the test pipelines below: the ATen-level name and
# the corresponding edge-dialect name for expm1.
aten_op = "torch.ops.aten.expm1.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_expm1_default"

# Type parameter for the pipelines: a one-element tuple holding the input tensor.
input_t1 = Tuple[torch.Tensor]

# Named input tensors; every test below is parametrized over these entries.
test_data_suite = {
    "zeroes": torch.zeros(1, 10, 10, 10),
    "ones": torch.ones(10, 2, 3),
    # Shifted by -0.5 so the values are centred on zero.
    "rand": torch.rand(10, 10) - 0.5,
    "near_zero": torch.randn(100) * 0.01,
    # The bounds here equal the -0.35 / 0.35 cutoffs used by DecomposeExpm1Pass.
    "taylor_small": torch.empty(5).uniform_(
        -0.35, 0.35
    ), # test cases for taylor series expansion
    "randn_large_pos": torch.randn(10) + 10,
    "randn_large_neg": torch.randn(10) - 10,
    "ramp": torch.arange(-16, 16, 0.2),
}
36+
37+
38+
class Expm1(torch.nn.Module):
    """Minimal module under test: applies ``torch.expm1`` to its input."""

    def forward(self, x: torch.Tensor):
        # Element-wise exp(x) - 1 of the input tensor.
        result = torch.expm1(x)
        return result
42+
43+
44+
@common.parametrize("test_data", test_data_suite)
def test_expm1_tosa_FP(test_data: Tuple):
    """Run the Expm1 module through TosaPipelineFP for one suite entry."""
    module = Expm1()
    example_inputs = (test_data,)
    TosaPipelineFP[input_t1](
        module,
        example_inputs,
        aten_op=aten_op,
        exir_op=exir_op,
    ).run()
53+
54+
55+
@common.parametrize("test_data", test_data_suite)
def test_expm1_tosa_INT(test_data: Tuple):
    """Run the Expm1 module through TosaPipelineINT for one suite entry."""
    module = Expm1()
    example_inputs = (test_data,)
    TosaPipelineINT[input_t1](
        module,
        example_inputs,
        aten_op=aten_op,
        exir_op=exir_op,
    ).run()
64+
65+
66+
@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_expm1_u55_INT(test_data: Tuple):
    """Run the Expm1 module through EthosU55PipelineINT for one suite entry."""
    module = Expm1()
    example_inputs = (test_data,)
    EthosU55PipelineINT[input_t1](
        module,
        example_inputs,
        aten_ops=aten_op,
        exir_ops=exir_op,
    ).run()
76+
77+
78+
@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_expm1_u85_INT(test_data: Tuple):
    """Run the Expm1 module through EthosU85PipelineINT for one suite entry."""
    module = Expm1()
    example_inputs = (test_data,)
    EthosU85PipelineINT[input_t1](
        module,
        example_inputs,
        aten_ops=aten_op,
        exir_ops=exir_op,
    ).run()
88+
89+
90+
@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_expm1_vgf_FP(test_data: Tuple):
    """Run the Expm1 module through VgfPipeline with the TOSA-1.0+FP version."""
    module = Expm1()
    example_inputs = (test_data,)
    VgfPipeline[input_t1](
        module,
        example_inputs,
        aten_op,
        exir_op,
        tosa_version="TOSA-1.0+FP",
    ).run()
101+
102+
103+
@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_expm1_vgf_INT(test_data: Tuple):
    """Run the Expm1 module through VgfPipeline with the TOSA-1.0+INT version."""
    module = Expm1()
    example_inputs = (test_data,)
    VgfPipeline[input_t1](
        module,
        example_inputs,
        aten_op,
        exir_op,
        tosa_version="TOSA-1.0+INT",
    ).run()

0 commit comments

Comments
 (0)