Commit a42423c
Arm backend: Add support for ELU.default operator (#13683)
Decomposes elu into other operators for the FP case and into a lookup table for the INT case.
1 parent 64db0c1 commit a42423c
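
For context, ELU is defined as ELU(x) = x for x >= 0 and alpha * (exp(x) - 1) for x < 0. The FP pipeline lowers it to exactly that expression, while the INT pipeline replaces it with a lookup table. A minimal reference check of the identity in plain PyTorch (not part of the commit):

    import torch

    def elu_reference(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
        # Identity for x >= 0, alpha * (exp(x) - 1) for x < 0.
        return torch.where(x >= 0.0, x, alpha * torch.expm1(x))

    x = torch.linspace(-3.0, 3.0, 7)
    assert torch.allclose(
        elu_reference(x, alpha=2.0), torch.nn.functional.elu(x, alpha=2.0)
    )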

8 files changed: +286 −0 lines
backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,7 @@
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
+from .convert_elu_params import ConvertELUParamsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
 from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass  # noqa
@@ -36,6 +37,7 @@
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_cumsum_pass import DecomposeCumsumPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_elu_pass import DecomposeEluPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions

@@ -18,6 +18,7 @@
     ComputeConstantOpsAOT,
     Conv1dUnsqueezePass,
     ConvertAnyDefaultDimDimsPass,
+    ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
     ConvertInt64ConstOpsToInt32Pass,
@@ -41,6 +42,7 @@
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
     DecomposeDivPass,
+    DecomposeEluPass,
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
     DecomposeGeluPass,
@@ -135,6 +137,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
+        self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
@@ -183,6 +186,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
         self.add_pass(DecomposeAddmmPass())
+        self.add_pass(DecomposeEluPass())
+        self.add_pass(DecomposeExpm1Pass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())
backends/arm/_passes/convert_elu_params.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertELUParamsPass(ExportPass):
+    """
+    Pass to convert the ELU operator's float kwargs (alpha, scale,
+    input_scale) to int.
+
+    input_scale is set to 2 because the outputs appear to stay the same
+    regardless of its value, as long as that value is not 1.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.elu.default
+        )
+        for node in node_list:
+            with graph.inserting_after(node):
+                replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
+                old_args = list(node.args)
+
+                alpha = old_args[1] if len(old_args) > 1 else 1.0
+                scale = 1.0
+                input_scale = 2.0
+
+                replace_node.args = (old_args[0],)
+
+                updated_kwargs = dict(node.kwargs)
+                updated_kwargs["alpha"] = int(alpha)
+                updated_kwargs["scale"] = int(scale)
+                updated_kwargs["input_scale"] = int(input_scale)
+
+                replace_node.kwargs = updated_kwargs
+
+                node.replace_all_uses_with(replace_node)
+                graph.erase_node(node)
+
+                modified_graph = True
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
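
In effect, the pass normalizes every elu node so that alpha moves from a positional arg into an integer kwarg and input_scale is pinned to 2. A hypothetical before/after sketch of a node's signature, using plain Python tuples and dicts as stand-ins for torch.fx node fields:

    # Before: aten.elu.default(x, 2.0)  -> args=(x, 2.0), kwargs={}
    old_args = ("x", 2.0)

    alpha = old_args[1] if len(old_args) > 1 else 1.0
    new_args = (old_args[0],)  # the input stays positional
    new_kwargs = {"alpha": int(alpha), "scale": int(1.0), "input_scale": int(2.0)}

    # After: aten.elu.default(x, alpha=2, scale=1, input_scale=2)
    assert new_kwargs == {"alpha": 2, "scale": 1, "input_scale": 2}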
backends/arm/_passes/decompose_elu_pass.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+edge_elu_ops = (exir_ops.edge.aten.elu.default,)
+
+
+def get_elu_decomposition(op) -> tuple:
+    """
+    Returns the decomposition of the given aten.elu operation into
+    its equivalent TOSA-supported operations.
+
+    The decomposition strategy is:
+        elu(x, alpha) → where(ge(x, 0), x, alpha * (exp(x) - 1))
+
+    Returns:
+        A tuple (expm1_op, ge_op, where_op, mul_op) corresponding to the
+        appropriate operator overloads for the input op.
+
+    Raises:
+        RuntimeError: If the provided operator is not a supported elu variant.
+    """
+
+    if op in edge_elu_ops:
+        return (
+            exir_ops.edge.aten.expm1.default,
+            exir_ops.edge.aten.ge.Scalar,
+            exir_ops.edge.aten.where.self,
+            exir_ops.edge.aten.mul.Scalar,
+        )
+
+    raise RuntimeError(f"Can't get elu decomposition for op {op}")
+
+
+class DecomposeEluPass(ArmPass):
+    """
+    A transformation pass that decomposes unsupported 'aten.elu' operations
+    into a combination of supported TOSA-equivalent operations.
+
+    Since TOSA does not provide a native ELU operator, this pass rewrites:
+        elu(x) → where(ge(x, 0), x, alpha * (exp(x) - 1))
+
+    Supported input ops:
+        - exir_ops.edge.aten.elu.default(x)
+
+    These are replaced with:
+        - exir_ops.edge.aten.expm1.default
+        - exir_ops.edge.aten.ge.Scalar
+        - exir_ops.edge.aten.where.self
+        - exir_ops.edge.aten.mul.Scalar
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in edge_elu_ops:
+            return super().call_operator(op, args, kwargs, meta, updated=False)
+
+        (
+            expm1_op,
+            ge_op,
+            where_op,
+            mul_op,
+        ) = get_elu_decomposition(op)
+
+        input = args[0]
+        alpha = args[1] if len(args) > 1 else 1.0
+
+        if alpha == 0:
+            # alpha == 0 reduces ELU to ReLU, so emit relu directly.
+            relu_op = exir_ops.edge.aten.relu.default
+            return super().call_operator(relu_op, (input,), {}, meta, updated=True)
+
+        expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
+        mul_node = super().call_operator(
+            mul_op, (expm1_node, alpha), {}, meta, updated=True
+        )
+        ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
+        where_node = super().call_operator(
+            where_op, (ge_node, input, mul_node), {}, meta, updated=True
+        )
+
+        return where_node
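
A quick numerical sanity check of the two paths the pass emits, written against plain torch ops rather than the edge dialect (not part of the commit):

    import torch

    x = torch.randn(100)

    # General case: the where/ge/expm1/mul decomposition matches aten.elu.
    for alpha in (0.5, 1.0, 2.0):
        decomposed = torch.where(x >= 0.0, x, torch.expm1(x) * alpha)
        assert torch.allclose(decomposed, torch.nn.functional.elu(x, alpha=alpha))

    # alpha == 0 zeroes the negative branch, so elu degenerates to relu;
    # this is the shortcut DecomposeEluPass takes.
    assert torch.allclose(torch.nn.functional.elu(x, alpha=0.0), torch.relu(x))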

backends/arm/_passes/insert_table_ops.py

Lines changed: 6 additions & 0 deletions

@@ -59,6 +59,7 @@ class TableOps:
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.elu.default,
     }

     def __init__(self, exported_program: ExportedProgram):
@@ -92,6 +93,11 @@ def __getitem__(self, node: Node):
                 return lambda x: torch.nn.functional.gelu(
                     x, approximate=approximate
                 ).flatten()
+            case exir_ops.edge.aten.elu.default:
+                input_alpha = cast(int, node.kwargs["alpha"])
+                return lambda x: torch.nn.functional.elu(
+                    x, alpha=input_alpha
+                ).flatten()
             case _:
                 # Op must be handled if it's inside self.special_ops
                 raise AssertionError("Unhandled table operation")
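
The lambda returned here gives insert_table_ops a reference implementation to tabulate over the quantized input domain. A rough sketch of the idea, assuming the usual 256-entry int8 table; the real pass also maps through the input/output quantization scales and rounding, which this omits:

    import torch

    alpha = 2  # already cast to int by ConvertELUParamsPass
    table_fn = lambda x: torch.nn.functional.elu(x, alpha=alpha).flatten()

    # Evaluate the op once per representable int8 value.
    int8_domain = torch.arange(-128, 128, dtype=torch.float32)
    table = table_fn(int8_domain)
    assert table.shape == (256,)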

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions

@@ -263,6 +263,7 @@ def is_node_supported(
             exir_ops.edge.aten.glu.default,
             exir_ops.edge.aten.logit.default,
             exir_ops.edge.aten.acos.default,
+            exir_ops.edge.aten.elu.default,
         ]

         return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions

@@ -266,6 +266,7 @@ def _match_pattern(
         torch.ops.aten.erf.default,
         torch.ops.aten.exp.default,
         torch.ops.aten.expm1.default,
+        torch.ops.aten.elu.default,
         torch.ops.aten.floor.default,
         torch.ops.aten.log.default,
         torch.ops.aten.reciprocal.default,

backends/arm/test/ops/test_elu.py

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+test_data_suite = {
+    # test_name: lambda returning (alpha, test_data)
+    "zeros_default": lambda: (1.0, torch.zeros(1, 10, 10, 10)),
+    "ones_default": lambda: (1.0, torch.ones(10, 10, 10)),
+    "rand_default": lambda: (1.0, torch.rand(10, 10) - 0.5),
+    "randn_pos_default": lambda: (1.0, torch.randn(1, 2, 3, 3) + 10),
+    "randn_neg_default": lambda: (1.0, torch.randn(2, 4, 3) - 10),
+    "ramp_default": lambda: (1.0, torch.arange(-16, 16, 0.2)),
+    "large_pos_default": lambda: (1.0, torch.randn(3, 3) * 1e6 + 1e7),
+    "large_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e5, 1e8)),
+    "small_pos_default": lambda: (1.0, torch.empty(5).uniform_(1e-8, 1e-5)),
+    "small_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
+    "zeros_custom": lambda: (2.0, torch.zeros(1, 10, 10, 10)),
+    "ones_custom": lambda: (2.0, torch.ones(10, 10, 10)),
+    "rand_custom": lambda: (2.0, torch.rand(10, 10) - 0.5),
+    "randn_pos_custom": lambda: (2.0, torch.randn(1, 3, 3) + 10),
+    "randn_neg_custom": lambda: (2.0, torch.randn(1, 2, 4, 3) - 10),
+    "ramp_custom": lambda: (2.0, torch.arange(-16, 16, 0.2)),
+    "large_pos_custom": lambda: (2.0, torch.randn(3, 3) * 1e6 + 1e7),
+    "large_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e5, 1e8)),
+    "small_pos_custom": lambda: (2.0, torch.empty(5).uniform_(1e-8, 1e-5)),
+    "small_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
+    "zeros_zero": lambda: (0.0, torch.zeros(1, 10, 10, 10)),
+    "ones_zero": lambda: (0.0, torch.ones(10, 10, 10)),
+    "rand_zero": lambda: (0.0, torch.rand(10, 10) - 0.5),
+    "randn_pos_zero": lambda: (0.0, torch.randn(1, 3, 3) + 10),
+    "randn_neg_zero": lambda: (0.0, torch.randn(1, 2, 4, 3) - 10),
+    "ramp_zero": lambda: (0.0, torch.arange(-16, 16, 0.2)),
+    "large_pos_zero": lambda: (0.0, torch.randn(3, 3) * 1e6 + 1e7),
+    "large_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e5, 1e8)),
+    "small_pos_zero": lambda: (0.0, torch.empty(5).uniform_(1e-8, 1e-5)),
+    "small_neg_zero": lambda: (0.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
+}
+
+
+class Elu(nn.Module):
+    aten_op = "torch.ops.aten.elu.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten__elu_default"
+
+    def __init__(self, input_alpha: float = 1.0):
+        super().__init__()
+        self.elu = torch.nn.ELU(alpha=input_alpha)
+
+    def forward(self, input_: torch.Tensor):
+        return self.elu(input_)
+
+
+input_t1 = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_elu_tosa_FP(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = TosaPipelineFP[input_t1](
+        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_data_suite)
+def test_elu_tosa_INT(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = TosaPipelineINT[input_t1](
+        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_module", test_data_suite)
+def test_elu_u55_INT(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = EthosU55PipelineINT[input_t1](
+        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_module", test_data_suite)
+def test_elu_u85_INT(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = EthosU85PipelineINT[input_t1](
+        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize("test_module", test_data_suite)
+def test_elu_vgf_FP(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = VgfPipeline[input_t1](
+        Elu(alpha),
+        (test_data,),
+        aten_op=Elu.aten_op,
+        exir_op=Elu.exir_op,
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.SkipIfNoModelConverter
+@common.parametrize("test_module", test_data_suite)
+def test_elu_vgf_INT(test_module: input_t1):
+    alpha, test_data = test_module()
+    pipeline = VgfPipeline[input_t1](
+        Elu(alpha),
+        (test_data,),
+        aten_op=Elu.aten_op,
+        exir_op=Elu.exir_op,
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
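
Assuming a standard ExecuTorch development checkout, the new suite should be runnable on its own with pytest, e.g. `pytest backends/arm/test/ops/test_elu.py`; the Corstone and model-converter variants xfail or skip automatically when the corresponding tooling is absent.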
