Skip to content

Commit fe4b480

Browse files
committed
Documentation and tests
1 parent 4d54209 commit fe4b480

File tree

6 files changed

+94
-20
lines changed

6 files changed

+94
-20
lines changed

src/Symbolics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ include("solver/polynomialization.jl")
210210
include("solver/attract.jl")
211211
include("solver/ia_main.jl")
212212
include("solver/main.jl")
213-
include("solver/ia_rules.jl")
213+
include("solver/special_cases.jl")
214214
export symbolic_solve
215215

216216
function symbolics_to_sympy end

src/solver/ia_main.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
4545

4646
if isequal(old_lhs, lhs)
4747
warns && @warn("This expression cannot be solved with the methods available to ia_solve. Try a numerical method instead.")
48-
return nothing
48+
return nothing, conditions
4949
end
5050

5151
old_lhs = deepcopy(lhs)
@@ -82,7 +82,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
8282
else
8383
# 2 / x = y
8484
lhs = args[2]
85-
rhs = map(sol -> args[1] // sol, rhs)
85+
rhs = map(sol -> term(/, args[1], sol), rhs)
8686
end
8787

8888
elseif oper === (^)
@@ -114,6 +114,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
114114
elseif any(isequal(x, var) for x in get_variables(args[1])) &&
115115
n_occurrences(args[2], var) == 0
116116
lhs = args[1]
117+
s, args[2] = filter_stuff(args[2])
117118
rhs = map(sol -> term(^, sol, 1 // args[2]), rhs)
118119
else
119120
lhs = args[2]
@@ -280,9 +281,9 @@ function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots =
280281
conditions = []
281282
if nx == 0
282283
warns && @warn("Var not present in given expression")
283-
return []
284+
return nothing
284285
elseif nx == 1
285-
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
286+
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
286287
elseif nx > 1
287288
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
288289
end

src/solver/main.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
173173
expr = Vector{Num}(expr)
174174
end
175175

176-
if expr_univar && !x_univar
177-
expr = [expr]
178-
expr_univar = false
179-
end
180176
if !expr_univar && x_univar
181177
x = [x]
182178
x_univar = false
@@ -189,8 +185,17 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
189185
isequal(sols, nothing) && return nothing
190186
sols = map(postprocess_root, sols)
191187
return sols
188+
elseif expr_univar
189+
all_vars = get_variables(expr)
190+
diff_vars = setdiff(wrap.(all_vars), x)
191+
if length(diff_vars) == 1
192+
return solve_interms_ofvar(expr, diff_vars[1], dropmultiplicity=dropmultiplicity, warns=warns)
193+
end
194+
195+
expr = [expr]
192196
end
193197

198+
194199
if !x_univar
195200
for e in expr
196201
for var in x
@@ -237,6 +242,7 @@ function symbolic_solve(expr; x...)
237242
return symbolic_solve(expr, vars; x...)
238243
end
239244

245+
240246
"""
241247
solve_univar(expression, x; dropmultiplicity=true)
242248
This solver uses analytic solutions up to degree 4 to solve univariate polynomials.
@@ -256,6 +262,8 @@ implemented in the function `get_roots` and its children.
256262
257263
- dropmultiplicity (optional): Print repeated roots or not?
258264
265+
- strict (optional): Bool that enables/disables strict assert if input expression is a univariate polynomial or not. If strict=true and expression is not a polynomial, `solve_univar` throws an assertion error.
266+
259267
# Examples
260268
261269
"""
@@ -277,8 +285,8 @@ function solve_univar(expression, x; dropmultiplicity=true, strict=true)
277285
end
278286

279287
subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true)
280-
if strict
281-
@assert check_polynomial(filtered_expr) "This expression could not be solved by `symbolic_solve`."
288+
if !strict && !check_polynomial(filtered_expr, strict=false)
289+
return [RootsOf(wrap(expression), wrap(x))]
282290
end
283291
coeffs, constant = polynomial_coeffs(filtered_expr, [x])
284292
degree = sdegree(coeffs, x)

src/solver/postprocess.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,26 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
4343
end
4444
end
4545

46+
args = arguments(x)
47+
4648
# (X)^0 => 1
47-
if oper === (^) && isequal(arguments(x)[2], 0)
49+
if oper === (^) && isequal(args[2], 0) && !isequal(args[1], 0)
4850
return 1
4951
end
5052

5153
# (X)^1 => X
52-
if oper === (^) && isequal(arguments(x)[2], 1)
53-
return arguments(x)[1]
54+
if oper === (^) && isequal(args[2], 1)
55+
return args[1]
56+
end
57+
58+
# (0)^X => 0
59+
if oper === (^) && isequal(args[1], 0) && !isequal(args[2], 0)
60+
return 0
61+
end
62+
63+
# y / 0 => Inf
64+
if oper === (/) && !isequal(args[1], 0) && isequal(args[2], 0)
65+
return Inf
5466
end
5567

5668
# sqrt((N / D)^2 * M) => N / D * sqrt(M)

src/solver/ia_rules.jl renamed to src/solver/special_cases.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,47 @@ function cross_multiply(eq)
3737
return cross_multiply(eq)
3838
end
3939
end
40+
"""
41+
solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
42+
This special case solver expects a single equation in multiple variables and a
43+
variable `s` (this can be any Num, `s` is used for convenience). The function generates
44+
a system of equations to by observing the coefficients of the powers of `s` present in `eq`.
45+
E.g. a system would look like `a+b = 1`, `a-2b = 3` for the eq `(a+b)s + (a-2b)s^2 - (1)s - (3)s^2 = 0`.
46+
After generating this system, it calls `symbolic_solve`, which uses `solve_multivar`. `symbolic_solve` was chosen
47+
instead of `solve_multivar` because it postprocesses the roots in order to simplify them and make them more user friendly.
4048
49+
Generation of system uses cross multiplication in order to simplify the equation and convert it
50+
to a polynomial like shape.
51+
52+
53+
# Arguments
54+
- eq: Single symbolics Num or SymbolicUtils.BasicSymbolic. This is equated to 0 and then solved. E.g. `expr = x+2`, we solve `x+2 = 0`
55+
56+
- s: Variable to "isolate", i.e. ignore and generate the system of equations based on this variable's coefficients.
57+
58+
- dropmultiplicity (optional): Print repeated roots or not?
59+
60+
- warns (optional, this is not used currently): Warn user when something is wrong or not.
61+
62+
# Examples
63+
```jldoctest
64+
julia> @variables a b x s;
65+
66+
julia> eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3;
67+
68+
julia> Symbolics.solve_interms_ofvar(eq, s)
69+
2-element Vector{Any}:
70+
Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 - 1im)*√(5))
71+
Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 + 1im)*√(5))
72+
```
73+
```jldoctest
74+
julia> eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d));
75+
76+
julia> Symbolics.solve_interms_ofvar(eq, s)
77+
1-element Vector{Any}:
78+
Dict{Num, Any}(a => 1, d => 1, b => 1, c => 1)
79+
```
80+
"""
4181
function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
4282
@assert iscall(unwrap(eq))
4383
vars = Symbolics.get_variables(eq)

test/solver.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,24 @@ end
6868
@testset "Solving in terms of a constant var" begin
6969
eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d))
7070
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d])
71+
solve_roots = sort_arr(symbolic_solve(eq, [a,b,c,d]), [a,b,c,d])
7172
known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d])
7273
@test check_approx(calcd_roots, known_roots)
74+
@test check_approx(solve_roots, known_roots)
7375

7476
eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s
7577
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
78+
solve_roots = sort_arr(symbolic_solve(eq, [a,b]), [a,b])
7679
known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b])
7780
@test check_approx(calcd_roots, known_roots)
81+
@test check_approx(solve_roots, known_roots)
7882

7983
eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3
80-
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
84+
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,x])
85+
solve_roots = sort_arr(symbolic_solve(eq, [a,b,x]), [a,b,x])
8186
known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x])
8287
@test check_approx(calcd_roots, known_roots)
88+
@test check_approx(solve_roots, known_roots)
8389
end
8490

8591
@testset "Invalid input" begin
@@ -345,14 +351,18 @@ end
345351
@testset "Post Process roots" begin
346352
SymbolicUtils.@syms __x
347353
__symsqrt(x) = SymbolicUtils.term(ssqrt, x)
354+
term = SymbolicUtils.term
348355
@test Symbolics.postprocess_root(2 // 1) == 2 && Symbolics.postprocess_root(2 + 0*im) == 2
349356
@test Symbolics.postprocess_root(__symsqrt(4)) == 2
350357
@test isequal(Symbolics.postprocess_root(__symsqrt(__x)^2), __x)
351358

352-
@test Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 0) ) == 1
353-
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.e, 0) ) == 1
354-
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi
355-
@test isequal(Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 1) ), __x)
359+
360+
@test isequal(Symbolics.postprocess_root(term(^, 0, __x)), 0)
361+
@test_broken isequal(Symbolics.postprocess_root(term(/, __x, 0)), Inf)
362+
@test Symbolics.postprocess_root(term(^, __x, 0) ) == 1
363+
@test Symbolics.postprocess_root(term(^, Base.MathConstants.e, 0) ) == 1
364+
@test Symbolics.postprocess_root(term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi
365+
@test isequal(Symbolics.postprocess_root(term(^, __x, 1) ), __x)
356366

357367
x = Symbolics.term(sqrt, 2)
358368
@test isequal(Symbolics.postprocess_root( expand((x + 1)^4) ), 17 + 12x)
@@ -416,7 +426,10 @@ end
416426
lhs = ia_solve(a*x^b + c, x)[1]
417427
lhs2 = symbolic_solve(a*x^b + c, x)[1]
418428
rhs = Symbolics.term(^, -c.val/a.val, 1/b.val)
419-
#@test isequal(lhs, rhs)
429+
@test_broken isequal(lhs, rhs)
430+
431+
@test isequal(symbolic_solve(2/x, x)[1], Inf)
432+
@test isequal(symbolic_solve(x^1.5, x)[1], 0)
420433

421434
lhs = symbolic_solve(log(a*x)-b,x)[1]
422435
@test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)

0 commit comments

Comments
 (0)