1-
21# Alex: make sure `Num`s are not processed here as they'd break it.
32_postprocess_root (x) = x
43
@@ -32,30 +31,30 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
3231 ! iscall (x) && return x
3332
3433 x = Symbolics. term (operation (x), map (_postprocess_root, arguments (x))... )
34+ oper = operation (x)
3535
3636 # sqrt(0), cbrt(0) => 0
3737 # sqrt(1), cbrt(1) => 1
38- if iscall (x) &&
39- (operation (x) === sqrt || operation (x) === cbrt || operation (x) === ssqrt ||
40- operation (x) === scbrt)
38+ if (oper === sqrt || oper === cbrt || oper === ssqrt ||
39+ oper === scbrt)
4140 arg = arguments (x)[1 ]
4241 if isequal (arg, 0 ) || isequal (arg, 1 )
4342 return arg
4443 end
4544 end
4645
4746 # (X)^0 => 1
48- if iscall (x) && operation (x) === (^ ) && isequal (arguments (x)[2 ], 0 )
47+ if oper === (^ ) && isequal (arguments (x)[2 ], 0 )
4948 return 1
5049 end
5150
5251 # (X)^1 => X
53- if iscall (x) && operation (x) === (^ ) && isequal (arguments (x)[2 ], 1 )
52+ if oper === (^ ) && isequal (arguments (x)[2 ], 1 )
5453 return arguments (x)[1 ]
5554 end
5655
5756 # sqrt((N / D)^2 * M) => N / D * sqrt(M)
58- if iscall (x) && ( operation (x) === sqrt || operation (x) === ssqrt)
57+ if (oper === sqrt || oper === ssqrt)
5958 function squarefree_decomp (x:: Integer )
6059 square, squarefree = big (1 ), big (1 )
6160 for (p, d) in collect (Primes. factor (abs (x)))
@@ -90,7 +89,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
9089 end
9190
9291 # (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2))
93- if iscall (x) && operation (x) === (^ )
92+ if oper === (^ )
9493 arg1, arg2 = arguments (x)
9594 if iscall (arg1) && (operation (arg1) === sqrt || operation (arg1) === ssqrt)
9695 if arg2 isa Integer
@@ -105,6 +104,19 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
105104 end
106105 end
107106
107+ x = convert_consts (x)
108+
109+ if oper === (+ )
110+ args = arguments (x)
111+ for arg in args
112+ if isequal (arg, 0 )
113+ after_removing = setdiff (args, arg)
114+ isone (length (after_removing)) && return after_removing[1 ]
115+ return Symbolics. term (+ , after_removing)
116+ end
117+ end
118+ end
119+
108120 return x
109121end
110122
@@ -122,3 +134,54 @@ function postprocess_root(x)
122134 end
123135 x # unreachable
124136end
137+
138+
139+ inv_exacts = [0 , Symbolics. term (* , pi ),
140+ Symbolics. term (/ ,pi ,3 ),
141+ Symbolics. term (/ , pi , 2 ),
142+ Symbolics. term (/ , Symbolics. term (* , 2 , pi ), 3 ),
143+ Symbolics. term (/ , pi , 6 ),
144+ Symbolics. term (/ , Symbolics. term (* , 5 , pi ), 6 ),
145+ Symbolics. term (/ , pi , 4 )
146+ ]
147+ inv_evald = Symbolics. symbolic_to_float .(inv_exacts)
148+
149+ const inv_pairs = collect (zip (inv_exacts, inv_evald))
150+ """
151+ function convert_consts(x)
152+ This function takes BasicSymbolic terms as input (x) and attempts
153+ to simplify these basic symbolic terms using known values.
154+ Currently, this function only supports inverse trigonometric functions.
155+
156+ ## Examples
157+ ```jldoctest
158+ julia> Symbolics.convert_consts(Symbolics.term(acos, 0))
159+ π / 2
160+
161+ julia> Symbolics.convert_consts(Symbolics.term(atan, 0))
162+ 0
163+
164+ julia> Symbolics.convert_consts(Symbolics.term(atan, 1))
165+ π / 4
166+ ```
167+ """
168+ function convert_consts (x)
169+ ! iscall (x) && return x
170+
171+ oper = operation (x)
172+ inv_opers = [asin, acos, atan]
173+
174+ if any (isequal (oper, o) for o in inv_opers) && isempty (Symbolics. get_variables (x))
175+ val = Symbolics. symbolic_to_float (x)
176+ for (exact, evald) in inv_pairs
177+ if isapprox (evald, val)
178+ return exact
179+ elseif isapprox (- evald, val)
180+ return - exact
181+ end
182+ end
183+ end
184+
185+ # add [sin, cos, tan] simplifications in the future?
186+ return x
187+ end
0 commit comments