Commit 10eac2d

Arm Backend: Add support for ELU.default operator
Signed-off-by: Agrima Khare <[email protected]>
Change-Id: I032414e7454d5e2cada05b788e9eed0f7b2dc97c
1 parent 275adee commit 10eac2d

8 files changed: +254 additions, −0 deletions

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
+from .convert_elu_params import ConvertELUParamsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
 from .convert_int_pow_to_mul import ConvertIntPowToMuls  # noqa
@@ -32,6 +33,7 @@
 from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_elu_pass import DecomposeEluPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
     ComputeConstantOpsAOT,
     Conv1dUnsqueezePass,
     ConvertAnyDefaultDimDimsPass,
+    ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
     ConvertIntPowToMuls,
@@ -37,6 +38,7 @@
     DecomposeBatchNormNoStatsPass,
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
+    DecomposeEluPass,
     DecomposeEmbeddingPass,
     DecomposeGeluPass,
     DecomposeGroupedConv,
@@ -127,6 +129,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
+        self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
@@ -171,6 +174,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
         self.add_pass(DecomposeAddmmPass())
+        self.add_pass(DecomposeEluPass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())

backends/arm/_passes/convert_elu_params.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertELUParamsPass(ExportPass):
    """
    Pass to convert the input_scale kwarg of the ELU operator from float
    to int.

    input_scale is set to 2 because the outputs appear to stay the same
    for any value of input_scale other than 1.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        modified_graph = False
        graph = graph_module.graph
        node_list = graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.elu.default
        )
        for node in node_list:
            with graph.inserting_after(node):
                replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
                replace_node.args = (
                    node.args[0],
                    int(node.args[1]) if len(node.args) > 1 else 1,
                )
                updated_kwargs = dict(node.kwargs)
                updated_kwargs["input_scale"] = int(2)
                replace_node.kwargs = updated_kwargs

                node.replace_all_uses_with(replace_node)
                graph.erase_node(node)

                modified_graph = True
        if modified_graph:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified_graph)
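
For context: this pass rewrites the edge-dialect ELU node, but core aten exposes the same signature, elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1), which makes the rewrite easy to illustrate outside the compiler. A minimal sketch, not part of this commit:

import torch

x = torch.linspace(-2.0, 2.0, steps=5)

# Before the pass: alpha and input_scale may arrive as floats.
y_before = torch.ops.aten.elu.default(x, 1.0, 1.0, 1.0)

# After the pass: alpha is cast to int and input_scale is pinned to 2.
y_after = torch.ops.aten.elu.default(x, 1, 1.0, 2)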

backends/arm/_passes/decompose_elu_pass.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

edge_elu_ops = (exir_ops.edge.aten.elu.default,)
aten_elu_ops = (torch.ops.aten.elu.default, torch.ops.aten.elu_.default)


def get_elu_decomposition(op) -> tuple:
    """
    Returns the decomposition of the given aten.elu operation into
    its equivalent TOSA-supported operations.

    This handles both edge dialect ops and core PyTorch ops. The
    decomposition strategy is:
        elu(x, alpha) → where(greater_or_eq(x, 0), x, alpha*(exp(x)-1))

    Returns:
        A tuple (add_op, exp_op, ge_op, where_op, mul_op) corresponding to
        the appropriate operator overloads for the input op.

    Raises:
        RuntimeError: If the provided operator is not a supported elu variant.
    """

    if op in edge_elu_ops:
        return (
            exir_ops.edge.aten.add.Scalar,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.ge.Scalar,
            exir_ops.edge.aten.where.self,
            exir_ops.edge.aten.mul.Scalar,
        )

    if op in aten_elu_ops:
        return (
            torch.ops.aten.add.Scalar,
            torch.ops.aten.exp.default,
            torch.ops.aten.ge.Scalar,
            torch.ops.aten.where.self,
            torch.ops.aten.mul.Scalar,
        )

    raise RuntimeError(f"Can't get elu decomposition for op {op}")


class DecomposeEluPass(ArmPass):
    """
    A transformation pass that decomposes unsupported 'aten.elu' operations
    into a combination of supported TOSA-equivalent operations.

    Since TOSA does not provide a native ELU operator, this pass rewrites:
        elu(x) → where(greater_or_eq(x, 0), x, alpha*(exp(x)-1))

    Supported input ops:
        - aten.elu(x)
        - aten.elu_(x)
        - exir_ops.edge.aten.elu.default(x)

    These are replaced with:
        - aten.exp.default or exir_ops.edge.aten.exp.default
        - aten.add.Scalar or exir_ops.edge.aten.add.Scalar
        - aten.ge.Scalar or exir_ops.edge.aten.ge.Scalar
        - aten.where.self or exir_ops.edge.aten.where.self
        - aten.mul.Scalar or exir_ops.edge.aten.mul.Scalar
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (edge_elu_ops + aten_elu_ops):
            return super().call_operator(op, args, kwargs, meta, updated=False)

        (
            add_op,
            exp_op,
            ge_op,
            where_op,
            mul_op,
        ) = get_elu_decomposition(op)

        input = args[0]
        alpha = int(args[1]) if len(args) > 1 else 1

        exp_node = super().call_operator(exp_op, (input,), {}, meta, updated=True)
        sub_node = super().call_operator(
            add_op, (exp_node, -1.0), {}, meta, updated=True
        )
        mul_node = super().call_operator(
            mul_op, (sub_node, alpha), {}, meta, updated=True
        )
        ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
        where_node = super().call_operator(
            where_op, (ge_node, input, mul_node), {}, meta, updated=True
        )

        return where_node
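
The decomposition is easy to sanity-check in eager mode against PyTorch's reference ELU; a small self-contained sketch (not part of this commit):

import torch

def elu_decomposed(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Mirrors the pass's rewrite: where(x >= 0, x, alpha * (exp(x) - 1))
    return torch.where(x >= 0, x, alpha * (torch.exp(x) - 1.0))

x = torch.linspace(-4.0, 4.0, steps=101)
assert torch.allclose(elu_decomposed(x, alpha=2.0), torch.nn.functional.elu(x, alpha=2.0))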

backends/arm/_passes/insert_table_ops.py

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,7 @@ class TableOps:
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.elu.default,
     }

     def __init__(self, exported_program: ExportedProgram):
@@ -97,6 +98,11 @@ def __getitem__(self, node: Node):
                 return lambda x: torch.nn.functional.gelu(
                     x, approximate=approximate
                 ).flatten()
+            case exir_ops.edge.aten.elu.default:
+                input_alpha = cast(int, node.args[1]) if len(node.args) > 1 else 1
+                return lambda x: torch.nn.functional.elu(
+                    x, alpha=input_alpha
+                ).flatten()
             case _:
                 # Op must be handled if it's inside self.special_ops
                 raise AssertionError("Unhandled table operation")
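
In the quantized (BI) flow, the table-op machinery in insert_table_ops.py uses the returned lambda to evaluate the operator once for every representable input value and bakes the results into a lookup table. A hypothetical, self-contained sketch of that idea over the int8 domain (the scale and zero-point are made-up values, and output re-quantization is omitted):

import torch

def build_elu_table(input_scale: float, input_zp: int, alpha: float = 1.0) -> torch.Tensor:
    # Evaluate the reference function for every possible int8 input value.
    q = torch.arange(-128, 128, dtype=torch.float32)
    x = (q - input_zp) * input_scale  # dequantize each quantized value
    return torch.nn.functional.elu(x, alpha=alpha).flatten()

table = build_elu_table(input_scale=0.05, input_zp=0, alpha=1.0)
assert table.numel() == 256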

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
@@ -258,6 +258,7 @@ def is_node_supported(
             exir_ops.edge.aten.atanh.default,
             exir_ops.edge.aten.addmm.default,
             exir_ops.edge.aten.masked_fill.Scalar,
+            exir_ops.edge.aten.elu.default,
         ]

         return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ def _match_pattern(
     torch.ops.aten.ceil.default,
     torch.ops.aten.erf.default,
     torch.ops.aten.exp.default,
+    torch.ops.aten.elu.default,
     torch.ops.aten.floor.default,
     torch.ops.aten.log.default,
     torch.ops.aten.reciprocal.default,

backends/arm/test/ops/test_elu.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn as nn

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineBI,
    EthosU85PipelineBI,
    TosaPipelineBI,
    TosaPipelineMI,
)

test_data_suite = {
    # (test_name, test_data)
    "zeros_default": lambda: (1.0, torch.zeros(1, 10, 10, 10)),
    "ones_default": lambda: (1.0, torch.ones(10, 10, 10)),
    "rand_default": lambda: (1.0, torch.rand(10, 10) - 0.5),
    "randn_pos_default": lambda: (1.0, torch.randn(1, 2, 3, 3) + 10),
    "randn_neg_default": lambda: (1.0, torch.randn(2, 4, 3) - 10),
    "ramp_default": lambda: (1.0, torch.arange(-16, 16, 0.2)),
    "large_pos_default": lambda: (1.0, torch.randn(3, 3) * 1e6 + 1e7),
    "large_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e5, 1e8)),
    "small_pos_default": lambda: (1.0, torch.empty(5).uniform_(1e-8, 1e-5)),
    "small_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
    "zeros_custom": lambda: (2.0, torch.zeros(1, 10, 10, 10)),
    "ones_custom": lambda: (2.0, torch.ones(10, 10, 10)),
    "rand_custom": lambda: (2.0, torch.rand(10, 10) - 0.5),
    "randn_pos_custom": lambda: (2.0, torch.randn(1, 3, 3) + 10),
    "randn_neg_custom": lambda: (2.0, torch.randn(1, 2, 4, 3) - 10),
    "ramp_custom": lambda: (2.0, torch.arange(-16, 16, 0.2)),
    "large_pos_custom": lambda: (2.0, torch.randn(3, 3) * 1e6 + 1e7),
    "large_neg_custom": lambda: (2, -torch.empty(5).uniform_(1e5, 1e8)),
    "small_pos_custom": lambda: (2.0, torch.empty(5).uniform_(1e-8, 1e-5)),
    "small_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
}


class Elu(nn.Module):
    aten_op = "torch.ops.aten.elu.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten__elu_default"

    def __init__(self, input_alpha: float = 1.0):
        super().__init__()
        self.elu = torch.nn.ELU(alpha=input_alpha)

    def forward(self, input_: torch.Tensor):
        return self.elu(input_)


input_t1 = Tuple[torch.Tensor]


@common.parametrize("test_module", test_data_suite)
def test_elu_tosa_MI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = TosaPipelineMI[input_t1](
        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
    )
    pipeline.run()


@common.parametrize("test_module", test_data_suite)
def test_elu_tosa_BI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = TosaPipelineBI[input_t1](
        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
    )
    pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_module", test_data_suite)
def test_elu_u55_BI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = EthosU55PipelineBI[input_t1](
        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
    )
    pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_module", test_data_suite)
def test_elu_u85_BI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = EthosU85PipelineBI[input_t1](
        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
    )
    pipeline.run()
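
The reference these pipelines compare against is plain eager PyTorch; a standalone sanity check mirroring the "ramp_custom" suite entry (a sketch, not part of this commit):

import torch

alpha, data = 2.0, torch.arange(-16, 16, 0.2)
elu = torch.nn.ELU(alpha=alpha)
expected = torch.where(data >= 0, data, alpha * (torch.exp(data) - 1.0))
assert torch.allclose(elu(data), expected)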
