Skip to content

Commit c87f256

Browse files
authored
Arm backend: Add decomposition and test for acos (#13414)
Add decomposition and test for acos Signed-off-by: Emma Kujala <[email protected]>
1 parent 46dd51a commit c87f256

File tree

7 files changed

+175
-46
lines changed

7 files changed

+175
-46
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2626
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
2727
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
28-
from .decompose_asin_pass import DecomposeAsinPass # noqa
28+
from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
2929
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
3030
from .decompose_atan_pass import DecomposeAtanPass # noqa
3131
from .decompose_atanh_pass import DecomposeAtanhPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
DecomposeAcoshPass,
3131
DecomposeAdaptiveAvgPool2dPass,
3232
DecomposeAddmmPass,
33+
DecomposeAsinAndAcosPass,
3334
DecomposeAsinhPass,
34-
DecomposeAsinPass,
3535
DecomposeAtanhPass,
3636
DecomposeAtanPass,
3737
DecomposeAvgPool2d,
@@ -171,9 +171,9 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
171171
self.add_pass(DecomposeMaskedFill())
172172
self.add_pass(DecomposeRoundPass())
173173
self.add_pass(DecomposeAcoshPass())
174-
self.add_pass(DecomposeAsinPass())
175174
self.add_pass(DecomposeAsinhPass())
176175
self.add_pass(DecomposeCoshPass())
176+
self.add_pass(DecomposeAsinAndAcosPass())
177177
self.add_pass(DecomposeSqrtPass())
178178
self.add_pass(DecomposeAtanPass())
179179
self.add_pass(DecomposeAtanhPass())

backends/arm/_passes/decompose_asin_pass.py renamed to backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616
# For MI case
1717
edge_asin_op = (exir_ops.edge.aten.asin.default,)
18+
edge_acos_op = (exir_ops.edge.aten.acos.default,)
1819

1920

20-
def get_asin_decomposition(op) -> tuple:
21-
if op in edge_asin_op:
21+
def get_decomposition(op) -> tuple:
22+
if op in (edge_asin_op + edge_acos_op):
2223
return (
2324
exir_ops.edge.aten.mul.Tensor,
2425
exir_ops.edge.aten.add.Tensor,
@@ -31,25 +32,26 @@ def get_asin_decomposition(op) -> tuple:
3132
exir_ops.edge.aten.lt.Scalar,
3233
exir_ops.edge.aten.sub.Tensor,
3334
exir_ops.edge.aten.full_like.default,
34-
exir_ops.edge.aten.where.self,
3535
exir_ops.edge.aten.neg.default,
3636
)
3737

38-
raise RuntimeError(f"Can't get asin decomposition for op {op}")
38+
raise RuntimeError(f"Can't get decomposition for op {op}")
3939

4040

41-
class DecomposeAsinPass(ArmPass):
41+
class DecomposeAsinAndAcosPass(ArmPass):
4242
"""
43-
This pass decomposes asin into a rational approximation for small values
43+
This pass decomposes asin and acos into a rational approximation for small values
4444
and a transformed rational approximation for large values.
45-
Example:
46-
y = asin(x)
47-
Becomes:
45+
46+
The decomposition is based on the following mathematical identities:
4847
if abs(x) <= 0.5:
49-
y = x + P(x^2) / Q(x^2)
48+
asin(x) = x + P(x^2) / Q(x^2)
49+
acos(x) = π/2 - asin(x)
5050
else:
51-
y = π/2 - 2 * (s + s^3 * Q(z) / P(z))
52-
where P and Q are polynomials defined in the function.
51+
asin(x) = π/2 - 2 * (s + s^3 * P(z) / Q(z))
52+
acos(x) = 2 * (s + s^3 * P(z) / Q(z)) for x > 0.5, and π/2 - asin(x) for x < -0.5
53+
where P and Q are polynomials defined in the function and s is the square root of z.
54+
5355
"""
5456

5557
def _build_polynomial(
@@ -84,11 +86,25 @@ def _build_polynomial(
8486
)
8587
return result
8688

89+
def _combine_branches(
90+
self,
91+
bool_op,
92+
bool_args: tuple[torch.Tensor, float],
93+
branches: tuple[torch.Tensor, torch.Tensor],
94+
meta: dict[str, str],
95+
) -> torch.Tensor:
96+
where_op = exir_ops.edge.aten.where.self
97+
mask = super().call_operator(bool_op, bool_args, {}, meta, True)
98+
branch_true, branch_false = branches
99+
return super().call_operator(
100+
where_op, (mask, branch_true, branch_false), {}, meta, True
101+
)
102+
87103
def call_operator(self, op, args, kwargs, meta):
88-
if op not in edge_asin_op:
104+
if op not in (edge_asin_op + edge_acos_op):
89105
return super().call_operator(op, args, kwargs, meta)
90106
logging.info(
91-
f"Approximating asin. This may introduce small numerical errors. For details, see {__file__}."
107+
f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}."
92108
)
93109
x = args[0]
94110
half = 0.5
@@ -111,9 +127,8 @@ def call_operator(self, op, args, kwargs, meta):
111127
lt_op,
112128
sub_op,
113129
full_like_op,
114-
where_op,
115130
neg_op,
116-
) = get_asin_decomposition(op)
131+
) = get_decomposition(op)
117132

118133
# Coefficients for the rational approximation, calculated with the Minimax (Remez) method
119134
p_coefficients = [
@@ -129,7 +144,6 @@ def call_operator(self, op, args, kwargs, meta):
129144
x_abs = super().call_operator(abs_op, (x,), {}, meta, True)
130145

131146
# Step 1: compute asin_small - rational approximation for [0,0.5]
132-
133147
y = super().call_operator(mul_op, (x_abs, x_abs), {}, meta, True)
134148
x3 = super().call_operator(mul_op, (x_abs, y), {}, meta, True)
135149

@@ -154,47 +168,40 @@ def call_operator(self, op, args, kwargs, meta):
154168
Qz = self._build_polynomial(q_coefficients, z, meta)
155169

156170
numer = super().call_operator(mul_op, (s3, Pz), {}, meta, True)
171+
157172
# Calculate r_large = s^3 * P(z) / Q(z)
158173
r_large = super().call_operator(div_op, (numer, Qz), {}, meta, True)
159174

160175
# Calculate asin_large = pi/2 - 2 * (s + s^3 * P(z) / Q(z))
161176
t1 = super().call_operator(add_op, (s, r_large), {}, meta, True)
162177
t2 = super().call_operator(mul_op_scalar, (t1, two), {}, meta, True)
178+
163179
diff = super().call_operator(sub_op_scalar, (t2, pi_over_2), {}, meta, True)
164180
tmp_neg_ones = super().call_operator(
165181
full_like_op, (diff, neg_one), {}, meta, True
166182
)
167183
asin_large = super().call_operator(mul_op, (diff, tmp_neg_ones), {}, meta, True)
168184

169-
# Combine branches
170-
is_large = super().call_operator(gt_op, (x_abs, half), {}, meta, True)
171-
asin_unsigned = super().call_operator(
172-
where_op,
173-
(
174-
is_large,
175-
asin_large,
176-
asin_small,
177-
),
178-
{},
179-
meta,
180-
True,
185+
asin_unsigned = self._combine_branches(
186+
gt_op, (x_abs, half), (asin_large, asin_small), meta
181187
)
182188

183189
# Handle x < 0
184-
is_neg = super().call_operator(lt_op, (x, zero), {}, meta, True)
185-
# Compute -asin_unsigned
186190
negated_asin = super().call_operator(neg_op, (asin_unsigned,), {}, meta, True)
187-
# Combine branches for signed asin
188-
asin_signed = super().call_operator(
189-
where_op,
190-
(
191-
is_neg,
192-
negated_asin,
193-
asin_unsigned,
194-
),
195-
{},
196-
meta,
197-
True,
191+
asin = self._combine_branches(
192+
lt_op, (x, zero), (negated_asin, asin_unsigned), meta
198193
)
199194

200-
return asin_signed
195+
if op in edge_acos_op:
196+
# If x <= 0.5: acos(x) = pi/2 - asin(x)
197+
const_tensor = super().call_operator(
198+
full_like_op, (x, pi_over_2), {}, meta, True
199+
)
200+
acos_small = super().call_operator(
201+
sub_op, (const_tensor, asin), {}, meta, True
202+
)
203+
# If x > 0.5, acos(x) = 2 * (s + s^3 * P(z) / Q(z)) = t2
204+
acos = self._combine_branches(gt_op, (x, half), (t2, acos_small), meta)
205+
return acos
206+
207+
return asin

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class TableOps:
6161
exir_ops.edge.aten.asin.default: torch.asin,
6262
exir_ops.edge.aten.asinh.default: torch.asinh,
6363
exir_ops.edge.aten.cosh.default: torch.cosh,
64+
exir_ops.edge.aten.acos.default: torch.acos,
6465
}
6566

6667
# Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def is_node_supported(
261261
exir_ops.edge.aten.cosh.default,
262262
exir_ops.edge.aten.glu.default,
263263
exir_ops.edge.aten.logit.default,
264+
exir_ops.edge.aten.acos.default,
264265
]
265266

266267
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def _match_pattern(
289289
torch.ops.aten.atanh.default,
290290
torch.ops.aten.asinh.default,
291291
torch.ops.aten.cosh.default,
292+
torch.ops.aten.acos.default,
292293
]
293294

294295
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_acos.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Tuple
6+
7+
import torch
8+
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineINT,
12+
EthosU85PipelineINT,
13+
TosaPipelineFP,
14+
TosaPipelineINT,
15+
VgfPipeline,
16+
)
17+
18+
input_t = Tuple[torch.Tensor]
19+
aten_op = "torch.ops.aten.acos.default"
20+
exir_op = "executorch_exir_dialects_edge__ops_aten__acos_default"
21+
22+
23+
test_data_suite = {
24+
"ones": lambda: torch.ones(1, 7, 10, 12),
25+
"rand_in_range": lambda: (torch.rand(10, 10) - 0.5) * 2, # Uniform in [-1, 1)
26+
"ramp_valid": lambda: torch.linspace(-1.0, 1.0, steps=160),
27+
"edge_cases": lambda: torch.tensor([-1.0, 0.0, 1.0]),
28+
"1d_tensor": lambda: torch.linspace(-1.0, 1.0, steps=10), # Shape: [10]
29+
"2d_batch": lambda: torch.tensor(
30+
[[-1.0, -0.5, 0.0, 0.5, 1.0], [0.9, -0.9, 0.3, -0.3, 0.0]]
31+
), # Shape: [2, 5]
32+
"3d_batch": lambda: torch.rand(4, 5, 6) * 2 - 1, # Shape: [4, 5, 6] in [-1, 1)
33+
"3d_mixed_shape": lambda: (torch.rand(7, 15, 2) - 0.5) * 2,
34+
"4d_mixed": lambda: torch.linspace(-1, 1, steps=1 * 3 * 4 * 5).reshape(
35+
1, 3, 4, 5
36+
), # Shape: [1, 3, 4, 5]
37+
"4d_random": lambda: (torch.rand(1, 5, 10, 7) - 0.5) * 2,
38+
"bool_casted": lambda: torch.ones(3, 3, dtype=torch.bool).to(
39+
dtype=torch.float32
40+
), # All 1.0 (edge case)
41+
}
42+
43+
44+
class Acos(torch.nn.Module):
45+
46+
def forward(self, x: torch.Tensor):
47+
return torch.acos(x)
48+
49+
50+
@common.parametrize("test_data", test_data_suite)
51+
def test_acos_tosa_FP(test_data: Tuple):
52+
pipeline = TosaPipelineFP[input_t](
53+
Acos(),
54+
(test_data(),),
55+
aten_op,
56+
exir_op=exir_op,
57+
)
58+
pipeline.run()
59+
60+
61+
@common.parametrize("test_data", test_data_suite)
62+
def test_acos_tosa_INT(test_data: Tuple):
63+
pipeline = TosaPipelineINT[input_t](
64+
Acos(),
65+
(test_data(),),
66+
aten_op=aten_op,
67+
exir_op=exir_op,
68+
)
69+
pipeline.run()
70+
71+
72+
@common.parametrize("test_data", test_data_suite)
73+
@common.XfailIfNoCorstone300
74+
def test_acos_u55_INT(test_data: Tuple):
75+
pipeline = EthosU55PipelineINT[input_t](
76+
Acos(),
77+
(test_data(),),
78+
aten_ops=aten_op,
79+
exir_ops=exir_op,
80+
)
81+
pipeline.run()
82+
83+
84+
@common.parametrize("test_data", test_data_suite)
85+
@common.XfailIfNoCorstone320
86+
def test_acos_u85_INT(test_data: Tuple):
87+
pipeline = EthosU85PipelineINT[input_t](
88+
Acos(),
89+
(test_data(),),
90+
aten_ops=aten_op,
91+
exir_ops=exir_op,
92+
)
93+
pipeline.run()
94+
95+
96+
@common.parametrize("test_data", test_data_suite)
97+
@common.SkipIfNoModelConverter
98+
def test_acos_vgf_FP(test_data: Tuple):
99+
pipeline = VgfPipeline[input_t](
100+
Acos(),
101+
(test_data(),),
102+
[],
103+
[],
104+
tosa_version="TOSA-1.0+FP",
105+
)
106+
pipeline.run()
107+
108+
109+
@common.parametrize("test_data", test_data_suite)
110+
@common.SkipIfNoModelConverter
111+
def test_acos_vgf_INT(test_data: Tuple):
112+
pipeline = VgfPipeline[input_t](
113+
Acos(),
114+
(test_data(),),
115+
[],
116+
[],
117+
tosa_version="TOSA-1.0+INT",
118+
)
119+
pipeline.run()

0 commit comments

Comments
 (0)