1515
1616# For MI case
1717edge_asin_op = (exir_ops .edge .aten .asin .default ,)
18+ edge_acos_op = (exir_ops .edge .aten .acos .default ,)
1819
1920
20- def get_asin_decomposition (op ) -> tuple :
21- if op in edge_asin_op :
21+ def get_decomposition (op ) -> tuple :
22+ if op in ( edge_asin_op + edge_acos_op ) :
2223 return (
2324 exir_ops .edge .aten .mul .Tensor ,
2425 exir_ops .edge .aten .add .Tensor ,
@@ -31,25 +32,26 @@ def get_asin_decomposition(op) -> tuple:
3132 exir_ops .edge .aten .lt .Scalar ,
3233 exir_ops .edge .aten .sub .Tensor ,
3334 exir_ops .edge .aten .full_like .default ,
34- exir_ops .edge .aten .where .self ,
3535 exir_ops .edge .aten .neg .default ,
3636 )
3737
38- raise RuntimeError (f"Can't get asin decomposition for op { op } " )
38+ raise RuntimeError (f"Can't get decomposition for op { op } " )
3939
4040
41- class DecomposeAsinPass (ArmPass ):
41+ class DecomposeAsinAndAcosPass (ArmPass ):
4242 """
43- This pass decomposes asin into a rational approximation for small values
43+ This pass decomposes asin and acos into a rational approximation for small values
4444 and a transformed rational approximation for large values.
45- Example:
46- y = asin(x)
47- Becomes:
45+
46+ The decomposition is based on the following mathematical identities:
4847 if abs(x) < 0.5:
49- y = x + P(x^2) / Q(x^2)
48+ asin(x) = x + P(x^2) / Q(x^2)
49+ acos(x) = π/2 - asin(x)
5050 else:
51- y = π/2 - 2 * (s + s^3 * Q(z) / P(z))
52- where P and Q are polynomials defined in the function.
51+ asin(x) = π/2 - 2 * (s + s^3 * Q(z) / P(z))
52+ acos(x) = 2 * (s + s^3 * Q(z) / P(z))
53+ where P and Q are polynomials defined in the function and s is the square root of z.
54+
5355 """
5456
5557 def _build_polynomial (
@@ -84,11 +86,25 @@ def _build_polynomial(
8486 )
8587 return result
8688
89+ def _combine_branches (
90+ self ,
91+ bool_op ,
92+ bool_args : tuple [torch .Tensor , float ],
93+ branches : tuple [torch .Tensor , torch .Tensor ],
94+ meta : dict [str , str ],
95+ ) -> torch .Tensor :
96+ where_op = exir_ops .edge .aten .where .self
97+ mask = super ().call_operator (bool_op , bool_args , {}, meta , True )
98+ branch_true , branch_false = branches
99+ return super ().call_operator (
100+ where_op , (mask , branch_true , branch_false ), {}, meta , True
101+ )
102+
87103 def call_operator (self , op , args , kwargs , meta ):
88- if op not in edge_asin_op :
104+ if op not in ( edge_asin_op + edge_acos_op ) :
89105 return super ().call_operator (op , args , kwargs , meta )
90106 logging .info (
91- f"Approximating asin . This may introduce small numerical errors. For details, see { __file__ } ."
107+ f"Approximating { op } . This may introduce small numerical errors. For details, see { __file__ } ."
92108 )
93109 x = args [0 ]
94110 half = 0.5
@@ -111,9 +127,8 @@ def call_operator(self, op, args, kwargs, meta):
111127 lt_op ,
112128 sub_op ,
113129 full_like_op ,
114- where_op ,
115130 neg_op ,
116- ) = get_asin_decomposition (op )
131+ ) = get_decomposition (op )
117132
118133 # Coefficients for the rational approximation, calculated with the Minimax (Remez) method
119134 p_coefficients = [
@@ -129,7 +144,6 @@ def call_operator(self, op, args, kwargs, meta):
129144 x_abs = super ().call_operator (abs_op , (x ,), {}, meta , True )
130145
131146 # Step 1: compute asin_small - rational approximation for [0,0.5]
132-
133147 y = super ().call_operator (mul_op , (x_abs , x_abs ), {}, meta , True )
134148 x3 = super ().call_operator (mul_op , (x_abs , y ), {}, meta , True )
135149
@@ -154,47 +168,40 @@ def call_operator(self, op, args, kwargs, meta):
154168 Qz = self ._build_polynomial (q_coefficients , z , meta )
155169
156170 numer = super ().call_operator (mul_op , (s3 , Pz ), {}, meta , True )
171+
157172 # Calculate r_large = P(z) / Q(z)
158173 r_large = super ().call_operator (div_op , (numer , Qz ), {}, meta , True )
159174
160175 # Calculate asin_large = pi/2 - 2 * (s + s^3 * Q(z) / P(z))
161176 t1 = super ().call_operator (add_op , (s , r_large ), {}, meta , True )
162177 t2 = super ().call_operator (mul_op_scalar , (t1 , two ), {}, meta , True )
178+
163179 diff = super ().call_operator (sub_op_scalar , (t2 , pi_over_2 ), {}, meta , True )
164180 tmp_neg_ones = super ().call_operator (
165181 full_like_op , (diff , neg_one ), {}, meta , True
166182 )
167183 asin_large = super ().call_operator (mul_op , (diff , tmp_neg_ones ), {}, meta , True )
168184
169- # Combine branches
170- is_large = super ().call_operator (gt_op , (x_abs , half ), {}, meta , True )
171- asin_unsigned = super ().call_operator (
172- where_op ,
173- (
174- is_large ,
175- asin_large ,
176- asin_small ,
177- ),
178- {},
179- meta ,
180- True ,
185+ asin_unsigned = self ._combine_branches (
186+ gt_op , (x_abs , half ), (asin_large , asin_small ), meta
181187 )
182188
183189 # Handle x < 0
184- is_neg = super ().call_operator (lt_op , (x , zero ), {}, meta , True )
185- # Compute -asin_unsigned
186190 negated_asin = super ().call_operator (neg_op , (asin_unsigned ,), {}, meta , True )
187- # Combine branches for signed asin
188- asin_signed = super ().call_operator (
189- where_op ,
190- (
191- is_neg ,
192- negated_asin ,
193- asin_unsigned ,
194- ),
195- {},
196- meta ,
197- True ,
191+ asin = self ._combine_branches (
192+ lt_op , (x , zero ), (negated_asin , asin_unsigned ), meta
198193 )
199194
200- return asin_signed
195+ if op in edge_acos_op :
196+ # If x <= 0.5: acos(x) = pi/2 - asin(x)
197+ const_tensor = super ().call_operator (
198+ full_like_op , (x , pi_over_2 ), {}, meta , True
199+ )
200+ acos_small = super ().call_operator (
201+ sub_op , (const_tensor , asin ), {}, meta , True
202+ )
203+ # If x > 0.5, acos(x) = 2 * (s + s^3 * Q(z) / P(z)) = t2
204+ acos = self ._combine_branches (gt_op , (x , half ), (t2 , acos_small ), meta )
205+ return acos
206+
207+ return asin
0 commit comments