15
15
16
16
# For MI case
17
17
edge_asin_op = (exir_ops .edge .aten .asin .default ,)
18
+ edge_acos_op = (exir_ops .edge .aten .acos .default ,)
18
19
19
20
20
- def get_asin_decomposition (op ) -> tuple :
21
- if op in edge_asin_op :
21
+ def get_decomposition (op ) -> tuple :
22
+ if op in ( edge_asin_op + edge_acos_op ) :
22
23
return (
23
24
exir_ops .edge .aten .mul .Tensor ,
24
25
exir_ops .edge .aten .add .Tensor ,
@@ -31,25 +32,26 @@ def get_asin_decomposition(op) -> tuple:
31
32
exir_ops .edge .aten .lt .Scalar ,
32
33
exir_ops .edge .aten .sub .Tensor ,
33
34
exir_ops .edge .aten .full_like .default ,
34
- exir_ops .edge .aten .where .self ,
35
35
exir_ops .edge .aten .neg .default ,
36
36
)
37
37
38
- raise RuntimeError (f"Can't get asin decomposition for op { op } " )
38
+ raise RuntimeError (f"Can't get decomposition for op { op } " )
39
39
40
40
41
- class DecomposeAsinPass (ArmPass ):
41
+ class DecomposeAsinAndAcosPass (ArmPass ):
42
42
"""
43
- This pass decomposes asin into a rational approximation for small values
43
+ This pass decomposes asin and acos into a rational approximation for small values
44
44
and a transformed rational approximation for large values.
45
- Example:
46
- y = asin(x)
47
- Becomes:
45
+
46
+ The decomposition is based on the following mathematical identities:
48
47
if abs(x) < 0.5:
49
- y = x + P(x^2) / Q(x^2)
48
+ asin(x) = x + P(x^2) / Q(x^2)
49
+ acos(x) = π/2 - asin(x)
50
50
else:
51
- y = π/2 - 2 * (s + s^3 * Q(z) / P(z))
52
- where P and Q are polynomials defined in the function.
51
+ asin(x) = π/2 - 2 * (s + s^3 * P(z) / Q(z))
52
+ acos(x) = 2 * (s + s^3 * P(z) / Q(z))
53
+ where P and Q are polynomials defined in the function and s is the square root of z.
54
+
53
55
"""
54
56
55
57
def _build_polynomial (
@@ -84,11 +86,25 @@ def _build_polynomial(
84
86
)
85
87
return result
86
88
89
def _combine_branches(
    self,
    bool_op,
    bool_args: tuple[torch.Tensor, float],
    branches: tuple[torch.Tensor, torch.Tensor],
    meta: dict[str, str],
) -> torch.Tensor:
    """Select element-wise between two branches using a comparison mask.

    Emits ``bool_op`` on ``bool_args`` to build a mask, then emits an
    ``aten.where.self`` node that picks ``branches[0]`` where the mask
    holds and ``branches[1]`` everywhere else.

    Args:
        bool_op: Comparison operator used to build the mask.
        bool_args: Positional arguments forwarded to ``bool_op``.
        branches: Pair ``(value_if_true, value_if_false)``.
        meta: Node metadata forwarded unchanged to both emitted operators.

    Returns:
        The result of the emitted ``where`` operator.
    """
    select_op = exir_ops.edge.aten.where.self
    # Emit the mask before unpacking, matching the original statement order.
    condition = super().call_operator(bool_op, bool_args, {}, meta, True)
    on_true, on_false = branches
    select_args = (condition, on_true, on_false)
    return super().call_operator(select_op, select_args, {}, meta, True)
102
+
87
103
def call_operator (self , op , args , kwargs , meta ):
88
- if op not in edge_asin_op :
104
+ if op not in ( edge_asin_op + edge_acos_op ) :
89
105
return super ().call_operator (op , args , kwargs , meta )
90
106
logging .info (
91
- f"Approximating asin . This may introduce small numerical errors. For details, see { __file__ } ."
107
+ f"Approximating { op } . This may introduce small numerical errors. For details, see { __file__ } ."
92
108
)
93
109
x = args [0 ]
94
110
half = 0.5
@@ -111,9 +127,8 @@ def call_operator(self, op, args, kwargs, meta):
111
127
lt_op ,
112
128
sub_op ,
113
129
full_like_op ,
114
- where_op ,
115
130
neg_op ,
116
- ) = get_asin_decomposition (op )
131
+ ) = get_decomposition (op )
117
132
118
133
# Coefficients for the rational approximation, calculated with the Minimax (Remez) method
119
134
p_coefficients = [
@@ -129,7 +144,6 @@ def call_operator(self, op, args, kwargs, meta):
129
144
x_abs = super ().call_operator (abs_op , (x ,), {}, meta , True )
130
145
131
146
# Step 1: compute asin_small - rational approximation for [0,0.5]
132
-
133
147
y = super ().call_operator (mul_op , (x_abs , x_abs ), {}, meta , True )
134
148
x3 = super ().call_operator (mul_op , (x_abs , y ), {}, meta , True )
135
149
@@ -154,47 +168,40 @@ def call_operator(self, op, args, kwargs, meta):
154
168
Qz = self ._build_polynomial (q_coefficients , z , meta )
155
169
156
170
numer = super ().call_operator (mul_op , (s3 , Pz ), {}, meta , True )
171
+
157
172
# Calculate r_large = s^3 * P(z) / Q(z)
158
173
r_large = super ().call_operator (div_op , (numer , Qz ), {}, meta , True )
159
174
160
175
# Calculate asin_large = pi/2 - 2 * (s + s^3 * P(z) / Q(z))
161
176
t1 = super ().call_operator (add_op , (s , r_large ), {}, meta , True )
162
177
t2 = super ().call_operator (mul_op_scalar , (t1 , two ), {}, meta , True )
178
+
163
179
diff = super ().call_operator (sub_op_scalar , (t2 , pi_over_2 ), {}, meta , True )
164
180
tmp_neg_ones = super ().call_operator (
165
181
full_like_op , (diff , neg_one ), {}, meta , True
166
182
)
167
183
asin_large = super ().call_operator (mul_op , (diff , tmp_neg_ones ), {}, meta , True )
168
184
169
- # Combine branches
170
- is_large = super ().call_operator (gt_op , (x_abs , half ), {}, meta , True )
171
- asin_unsigned = super ().call_operator (
172
- where_op ,
173
- (
174
- is_large ,
175
- asin_large ,
176
- asin_small ,
177
- ),
178
- {},
179
- meta ,
180
- True ,
185
+ asin_unsigned = self ._combine_branches (
186
+ gt_op , (x_abs , half ), (asin_large , asin_small ), meta
181
187
)
182
188
183
189
# Handle x < 0
184
- is_neg = super ().call_operator (lt_op , (x , zero ), {}, meta , True )
185
- # Compute -asin_unsigned
186
190
negated_asin = super ().call_operator (neg_op , (asin_unsigned ,), {}, meta , True )
187
- # Combine branches for signed asin
188
- asin_signed = super ().call_operator (
189
- where_op ,
190
- (
191
- is_neg ,
192
- negated_asin ,
193
- asin_unsigned ,
194
- ),
195
- {},
196
- meta ,
197
- True ,
191
+ asin = self ._combine_branches (
192
+ lt_op , (x , zero ), (negated_asin , asin_unsigned ), meta
198
193
)
199
194
200
- return asin_signed
195
+ if op in edge_acos_op :
196
+ # If x <= 0.5: acos(x) = pi/2 - asin(x)
197
+ const_tensor = super ().call_operator (
198
+ full_like_op , (x , pi_over_2 ), {}, meta , True
199
+ )
200
+ acos_small = super ().call_operator (
201
+ sub_op , (const_tensor , asin ), {}, meta , True
202
+ )
203
+ # If x > 0.5, acos(x) = 2 * (s + s^3 * P(z) / Q(z)) = t2
204
+ acos = self ._combine_branches (gt_op , (x , half ), (t2 , acos_small ), meta )
205
+ return acos
206
+
207
+ return asin
0 commit comments