Skip to content

Commit 00117cf

Browse files
authored
Arm backend: Add asinh decomposition pass and test (#13035)
Add decomposition pass and tests for asinh. Signed-off-by: Emma Kujala <[email protected]>
1 parent 308cad0 commit 00117cf

File tree

7 files changed

+135
-2
lines changed

7 files changed

+135
-2
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
2727
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
2828
from .decompose_asin_pass import DecomposeAsinPass # noqa
29+
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
2930
from .decompose_atan_pass import DecomposeAtanPass # noqa
3031
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
3132
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
DecomposeAcoshPass,
3131
DecomposeAdaptiveAvgPool2dPass,
3232
DecomposeAddmmPass,
33+
DecomposeAsinhPass,
3334
DecomposeAsinPass,
3435
DecomposeAtanhPass,
3536
DecomposeAtanPass,
@@ -114,7 +115,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
114115
self.add_pass(
115116
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
116117
)
117-
118118
self.add_pass(ConvertFullLikeToFullPass())
119119
self.add_pass(ConvertToClampPass())
120120
self.add_pass(ConvertMinMaxPass())
@@ -148,7 +148,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
148148
self.add_pass(DecomposeMaxPool2DPass())
149149
self.add_pass(SizeAdjustInputPass())
150150
self.add_pass(DecomposeSelectPass())
151-
152151
self.add_pass(ConvertSqueezesToViewPass())
153152

154153
self.add_pass(FuseViewCopyTransform())
@@ -167,6 +166,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
167166
self.add_pass(DecomposeRoundPass())
168167
self.add_pass(DecomposeAcoshPass())
169168
self.add_pass(DecomposeAsinPass())
169+
self.add_pass(DecomposeAsinhPass())
170170
self.add_pass(DecomposeSqrtPass())
171171
self.add_pass(DecomposeAtanPass())
172172
self.add_pass(DecomposeAtanhPass())
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
# For MI case
13+
edge_asinh_op = (exir_ops.edge.aten.asinh.default,)
14+
15+
16+
class DecomposeAsinhPass(ArmPass):
17+
"""
18+
Decomposes asinh to supported TOSA-operations.
19+
This decomposition is based on the mathematical identity:
20+
asinh(x) = log(x + sqrt(x^2 + 1))
21+
"""
22+
23+
def call_operator(self, op, args, kwargs, meta):
24+
if op not in edge_asinh_op:
25+
return super().call_operator(op, args, kwargs, meta)
26+
27+
log_op, sqrt_op, mul_op, add_op_scalar, add_op = (
28+
exir_ops.edge.aten.log.default,
29+
exir_ops.edge.aten.sqrt.default,
30+
exir_ops.edge.aten.mul.Tensor,
31+
exir_ops.edge.aten.add.Scalar,
32+
exir_ops.edge.aten.add.Tensor,
33+
)
34+
35+
x = args[0]
36+
37+
# calculate t1 = x^2 + 1
38+
x2 = super().call_operator(mul_op, (x, x), {}, meta, True)
39+
t1 = super().call_operator(add_op_scalar, (x2, 1.0), {}, meta, True)
40+
41+
# t2 = sqrt(t1)
42+
t2 = super().call_operator(sqrt_op, (t1,), {}, meta, True)
43+
44+
# t3 = x + t2
45+
t3 = super().call_operator(add_op, (x, t2), {}, meta, True)
46+
47+
# out = ln(t3)
48+
out = super().call_operator(log_op, (t3,), {}, meta, True)
49+
50+
return out

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class TableOps:
5858
exir_ops.edge.aten.sinh.default: torch.sinh,
5959
exir_ops.edge.aten.acosh.default: torch.acosh,
6060
exir_ops.edge.aten.asin.default: torch.asin,
61+
exir_ops.edge.aten.asinh.default: torch.asinh,
6162
}
6263

6364
# Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def is_node_supported(
258258
exir_ops.edge.aten.atanh.default,
259259
exir_ops.edge.aten.addmm.default,
260260
exir_ops.edge.aten.masked_fill.Scalar,
261+
exir_ops.edge.aten.asinh.default,
261262
]
262263

263264
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def _match_pattern(
219219
torch.ops.aten.sign.default,
220220
torch.ops.aten.asin.default,
221221
torch.ops.aten.atanh.default,
222+
torch.ops.aten.asinh.default,
222223
]
223224

224225
_one_to_one_shared_input_qspec = [
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
input_t = Tuple[torch.Tensor] # Input x
19+
aten_op = "torch.ops.aten.asinh.default"
20+
21+
test_data_suite = {
22+
"zeros": lambda: torch.zeros(1, 5, 3, 2),
23+
"ones": lambda: torch.ones(10, 10, 10),
24+
"neg_ones": lambda: -torch.ones(10, 10, 10),
25+
"rand": lambda: (torch.rand(10, 10) - 0.5) * 20,
26+
"ramp": lambda: torch.linspace(-10.0, 10.0, steps=160),
27+
"near_zero": lambda: torch.tensor([-1e-6, 0.0, 1e-6]),
28+
"large": lambda: torch.tensor([-100.0, -10.0, 0.0, 10.0, 100.0]),
29+
"rand_4d": lambda: torch.randn(1, 3, 4, 5),
30+
}
31+
32+
33+
class Asinh(torch.nn.Module):
34+
def forward(self, x):
35+
return torch.asinh(x)
36+
37+
38+
@common.parametrize("test_data", test_data_suite)
39+
def test_asin_tosa_MI(test_data: Tuple):
40+
pipeline = TosaPipelineMI[input_t](
41+
Asinh(),
42+
(test_data(),),
43+
aten_op,
44+
exir_op=[],
45+
)
46+
pipeline.run()
47+
48+
49+
@common.parametrize("test_data", test_data_suite)
50+
def test_asin_tosa_BI(test_data: Tuple):
51+
pipeline = TosaPipelineBI[input_t](
52+
Asinh(),
53+
(test_data(),),
54+
aten_op=[],
55+
exir_op=[],
56+
)
57+
pipeline.run()
58+
59+
60+
@common.parametrize("test_data", test_data_suite)
61+
@common.XfailIfNoCorstone300
62+
def test_asin_u55_BI(test_data: Tuple):
63+
pipeline = EthosU55PipelineBI[input_t](
64+
Asinh(),
65+
(test_data(),),
66+
aten_ops=[],
67+
)
68+
pipeline.run()
69+
70+
71+
@common.parametrize("test_data", test_data_suite)
72+
@common.XfailIfNoCorstone320
73+
def test_asin_u85_BI(test_data: Tuple):
74+
pipeline = EthosU85PipelineBI[input_t](
75+
Asinh(),
76+
(test_data(),),
77+
aten_ops=[],
78+
)
79+
pipeline.run()

0 commit comments

Comments
 (0)