Skip to content

Commit 33289f8

Browse files
tbergkvisthinriksnaer
authored and committed
Arm backend: Add sinh decomposition pass and test (pytorch#11848)
Decomposes sinh into other operators/lookup table for MI/BI case. Signed-off-by: Teo Bergkvist <[email protected]> Signed-off-by: Emma Kujala <[email protected]>
1 parent 5caf20d commit 33289f8

File tree

7 files changed

+137
-0
lines changed

7 files changed

+137
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .decompose_round_pass import DecomposeRoundPass # noqa
4040
from .decompose_select import DecomposeSelectPass # noqa
4141
from .decompose_silu_pass import DecomposeSiluPass # noqa
42+
from .decompose_sinh_pass import DecomposeSinhPass # noqa
4243
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
4344
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
4445
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
DecomposeRoundPass,
4343
DecomposeSelectPass,
4444
DecomposeSiluPass,
45+
DecomposeSinhPass,
4546
DecomposeSoftmaxPass,
4647
DecomposeSoftmaxUnstablePass,
4748
DecomposeSqrtPass,
@@ -151,6 +152,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
151152
self.add_pass(DecomposeSqrtPass())
152153
self.add_pass(ConvertIntPowToMuls())
153154
self.add_pass(CastBoolToInt8Pass())
155+
self.add_pass(DecomposeSinhPass())
154156
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
155157
self.add_pass(DecomposeEmbeddingPass())
156158
self.add_pass(FuseQuantizedActivationPass())
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from executorch.backends.arm._passes import ArmPass
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
10+
11+
# For MI case
12+
edge_sinh = exir_ops.edge.aten.sinh.default
13+
14+
15+
class DecomposeSinhPass(ArmPass):
    """Decompose ``aten.sinh`` into TOSA-supported primitives (MI case).

    Supported input ops:
        - exir_ops.edge.aten.sinh.default

    The op is rewritten using the identity ``sinh(x) = 0.5 * (e^x - e^-x)``,
    i.e. two exponentials, a negation, a subtraction, and a scalar multiply.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Pass every op other than sinh through unchanged.
        if op is not edge_sinh:
            return super().call_operator(op, args, kwargs, meta)

        exp_op = exir_ops.edge.aten.exp.default
        neg_op = exir_ops.edge.aten.neg.default
        sub_op = exir_ops.edge.aten.sub.Tensor
        mul_op = exir_ops.edge.aten.mul.Scalar

        # e^x
        exp_pos = super().call_operator(exp_op, args, {}, meta, updated=True)

        # e^-x
        neg_x = super().call_operator(neg_op, args, {}, meta, updated=True)
        exp_neg = super().call_operator(exp_op, (neg_x,), {}, meta, updated=True)

        # e^x - e^-x
        diff = super().call_operator(
            sub_op, (exp_pos, exp_neg), {}, meta, updated=True
        )

        # 0.5 * (e^x - e^-x)
        return super().call_operator(mul_op, (diff, 0.5), {}, meta, updated=True)

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class TableOps:
5353
exir_ops.edge.aten.tanh.default: torch.tanh,
5454
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5555
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
56+
exir_ops.edge.aten.sinh.default: torch.sinh,
5657
}
5758

5859
# Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def is_node_supported(
243243
torch.ops.aten.scalar_tensor.default,
244244
exir_ops.edge.aten.gelu.default,
245245
exir_ops.edge.aten.alias_copy.default,
246+
exir_ops.edge.aten.sinh.default,
246247
]
247248

248249
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def _match_pattern(
213213
torch.ops.aten.full_like.default,
214214
torch.ops.aten.pow.Tensor_Scalar,
215215
torch.ops.aten.gelu.default,
216+
torch.ops.aten.sinh.default,
216217
]
217218

218219
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_sinh.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineBI,
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
# Op names used by the pipelines to check the graph before and after lowering.
aten_op = "torch.ops.aten.sinh.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__sinh_default"


input_t1 = Tuple[torch.Tensor]  # Input x — single-tensor input signature


# Inputs covering zeros, ones, signed random values, shifted normal samples,
# a linear ramp across the origin, and very large / very small magnitudes.
# NOTE(review): rand/randn entries are unseeded, so those cases vary per run.
test_data_suite = {
    # (test_name, test_data)
    "zeros": torch.zeros(10, 10, 10),
    "zeros_alt_shape": torch.zeros(10, 3, 5),
    "ones": torch.ones(10, 10, 10),
    "rand": torch.rand(10, 10) - 0.5,
    "rand_alt_shape": torch.rand(10, 3, 5) - 0.5,
    "randn_pos": torch.randn(10) + 10,
    "randn_neg": torch.randn(10) - 10,
    "ramp": torch.arange(-16, 16, 0.2),
    "large": 100 * torch.ones(1, 1),
    "small": 0.000001 * torch.ones(1, 1),
}
36+
37+
38+
class Sinh(torch.nn.Module):
    """Minimal module applying hyperbolic sine, used as the export target."""

    def forward(self, x: torch.Tensor):
        # Tensor-method form of torch.sinh(x); traces to the same aten op.
        return x.sinh()
43+
44+
@common.parametrize("test_data", test_data_suite)
def test_sinh_tosa_MI(test_data: Tuple):
    """Run the TOSA MI (float) pipeline on the sinh module."""
    module = Sinh()
    TosaPipelineMI[input_t1](module, (test_data,), aten_op, exir_op).run()
53+
54+
55+
@common.parametrize("test_data", test_data_suite)
def test_sinh_tosa_BI(test_data: Tuple):
    """Run the TOSA BI (quantized) pipeline on the sinh module."""
    module = Sinh()
    pipeline = TosaPipelineBI[input_t1](
        module,
        (test_data,),
        aten_op=aten_op,
        exir_op=exir_op,
    )
    pipeline.run()
61+
62+
63+
@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_sinh_u55_BI(test_data: Tuple):
    """Run the Ethos-U55 BI pipeline (expected-fail without Corstone-300)."""
    module = Sinh()
    pipeline = EthosU55PipelineBI[input_t1](
        module,
        (test_data,),
        aten_ops=aten_op,
        exir_ops=exir_op,
    )
    pipeline.run()
70+
71+
72+
@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_sinh_u85_BI(test_data: Tuple):
    """Run the Ethos-U85 BI pipeline (expected-fail without Corstone-320)."""
    module = Sinh()
    pipeline = EthosU85PipelineBI[input_t1](
        module,
        (test_data,),
        aten_ops=aten_op,
        exir_ops=exir_op,
    )
    pipeline.run()

0 commit comments

Comments
 (0)