Skip to content

Commit c31c3fd

Browse files
Merge pull request #1357 from n0rbed/bug_fixes
Bug fixes, integration of solve_interms_ofvar and some changes
2 parents 2bc3c54 + 0368c3e commit c31c3fd

File tree

8 files changed

+117
-35
lines changed

8 files changed

+117
-35
lines changed

src/Symbolics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ include("solver/polynomialization.jl")
210210
include("solver/attract.jl")
211211
include("solver/ia_main.jl")
212212
include("solver/main.jl")
213-
include("solver/ia_rules.jl")
213+
include("solver/special_cases.jl")
214214
export symbolic_solve
215215

216216
function symbolics_to_sympy end

src/solver/ia_main.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
88
lhs = unwrap(lhs)
99

1010
old_lhs = nothing
11+
1112
while !isequal(lhs, var)
1213
subs, poly = filter_poly(lhs, var)
1314

14-
if check_poly_inunivar(poly, var)
15+
if check_polynomial(poly, strict=false)
1516
roots = []
1617
new_var = gensym()
1718
new_var = (@variables $new_var)[1]
@@ -20,7 +21,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
2021
else
2122
a, b, islin = linear_expansion(lhs - new_var, var)
2223
if islin
23-
lhs_roots = [-b / a]
24+
lhs_roots = [-b // a]
2425
else
2526
lhs_roots = [RootsOf(lhs - new_var, var)]
2627
if warns
@@ -31,15 +32,20 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
3132

3233
for i in eachindex(lhs_roots)
3334
for j in eachindex(rhs)
34-
push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false))
35+
if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf
36+
lhs_roots[i].arguments[1] = substitute(lhs_roots[i].arguments[1], Dict(new_var=>rhs[j]), fold=false)
37+
push!(roots, lhs_roots[i])
38+
else
39+
push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false))
40+
end
3541
end
3642
end
3743
return roots, conditions
3844
end
3945

4046
if isequal(old_lhs, lhs)
4147
warns && @warn("This expression cannot be solved with the methods available to ia_solve. Try a numerical method instead.")
42-
return nothing
48+
return nothing, conditions
4349
end
4450

4551
old_lhs = deepcopy(lhs)
@@ -76,7 +82,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
7682
else
7783
# 2 / x = y
7884
lhs = args[2]
79-
rhs = map(sol -> args[1] // sol, rhs)
85+
rhs = map(sol -> term(/, args[1], sol), rhs)
8086
end
8187

8288
elseif oper === (^)
@@ -108,6 +114,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
108114
elseif any(isequal(x, var) for x in get_variables(args[1])) &&
109115
n_occurrences(args[2], var) == 0
110116
lhs = args[1]
117+
s, args[2] = filter_stuff(args[2])
111118
rhs = map(sol -> term(^, sol, 1 // args[2]), rhs)
112119
else
113120
lhs = args[2]
@@ -169,7 +176,7 @@ function attract(lhs, var; warns = true, complex_roots = true, periodic_roots =
169176
return nothing, conditions
170177
end
171178
end
172-
179+
173180
new_var = collect(keys(sub))[1]
174181
new_var_val = collect(values(sub))[1]
175182

@@ -178,6 +185,7 @@ function attract(lhs, var; warns = true, complex_roots = true, periodic_roots =
178185
new_roots = []
179186

180187
for root in roots
188+
iscall(root) && operation(root) == RootsOf && continue
181189
new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots)
182190
append!(conditions, new_conds)
183191
push!(new_roots, new_sol)
@@ -273,9 +281,9 @@ function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots =
273281
conditions = []
274282
if nx == 0
275283
warns && @warn("Var not present in given expression")
276-
return []
284+
return nothing
277285
elseif nx == 1
278-
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
286+
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
279287
elseif nx > 1
280288
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
281289
end

src/solver/main.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
173173
expr = Vector{Num}(expr)
174174
end
175175

176-
if expr_univar && !x_univar
177-
expr = [expr]
178-
expr_univar = false
179-
end
180176
if !expr_univar && x_univar
181177
x = [x]
182178
x_univar = false
@@ -189,8 +185,17 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
189185
isequal(sols, nothing) && return nothing
190186
sols = map(postprocess_root, sols)
191187
return sols
188+
elseif expr_univar
189+
all_vars = get_variables(expr)
190+
diff_vars = setdiff(wrap.(all_vars), x)
191+
if length(diff_vars) == 1
192+
return solve_interms_ofvar(expr, diff_vars[1], dropmultiplicity=dropmultiplicity, warns=warns)
193+
end
194+
195+
expr = [expr]
192196
end
193197

198+
194199
if !x_univar
195200
for e in expr
196201
for var in x
@@ -247,6 +252,7 @@ function symbolic_solve(expr; x...)
247252
return symbolic_solve(expr, vars; x...)
248253
end
249254

255+
250256
"""
251257
solve_univar(expression, x; dropmultiplicity=true)
252258
This solver uses analytic solutions up to degree 4 to solve univariate polynomials.
@@ -266,10 +272,12 @@ implemented in the function `get_roots` and its children.
266272
267273
- dropmultiplicity (optional): Print repeated roots or not?
268274
275+
- strict (optional): Bool that enables/disables strict assert if input expression is a univariate polynomial or not. If strict=true and expression is not a polynomial, `solve_univar` throws an assertion error.
276+
269277
# Examples
270278
271279
"""
272-
function solve_univar(expression, x; dropmultiplicity=true)
280+
function solve_univar(expression, x; dropmultiplicity=true, strict=true)
273281
args = []
274282
mult_n = 1
275283
expression = unwrap(expression)
@@ -287,6 +295,9 @@ function solve_univar(expression, x; dropmultiplicity=true)
287295
end
288296

289297
subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true)
298+
if !strict && !check_polynomial(filtered_expr, strict=false)
299+
return [RootsOf(wrap(expression), wrap(x))]
300+
end
290301
coeffs, constant = polynomial_coeffs(filtered_expr, [x])
291302
degree = sdegree(coeffs, x)
292303

@@ -325,7 +336,6 @@ function solve_univar(expression, x; dropmultiplicity=true)
325336
end
326337

327338
if isequal(arr_roots, [])
328-
@assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`."
329339
return [RootsOf(wrap(expression), wrap(x))]
330340
end
331341

src/solver/nemo_stuff.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# Checks that the expression is a polynomial with integer or rational
22
# coefficients
3-
function check_polynomial(poly)
3+
function check_polynomial(poly; strict=true)
44
poly = wrap(poly)
55
vars = get_variables(poly)
66
distr, rem = polynomial_coeffs(poly, vars)
7-
@assert isequal(rem, 0) "Not a polynomial"
8-
@assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational"
9-
return true
7+
if strict
8+
@assert isequal(rem, 0) "Not a polynomial"
9+
@assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational"
10+
return true
11+
else
12+
return isequal(rem, 0)
13+
end
1014
end
1115

1216
# factor(x^2*y + b*x*y - a*x - a*b) -> (x*y - a)*(x + b)

src/solver/postprocess.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,26 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
4343
end
4444
end
4545

46+
args = arguments(x)
47+
4648
# (X)^0 => 1
47-
if oper === (^) && isequal(arguments(x)[2], 0)
49+
if oper === (^) && isequal(args[2], 0) && !isequal(args[1], 0)
4850
return 1
4951
end
5052

5153
# (X)^1 => X
52-
if oper === (^) && isequal(arguments(x)[2], 1)
53-
return arguments(x)[1]
54+
if oper === (^) && isequal(args[2], 1)
55+
return args[1]
56+
end
57+
58+
# (0)^X => 0
59+
if oper === (^) && isequal(args[1], 0) && !isequal(args[2], 0)
60+
return 0
61+
end
62+
63+
# y / 0 => Inf
64+
if oper === (/) && !isequal(args[1], 0) && isequal(args[2], 0)
65+
return Inf
5466
end
5567

5668
# sqrt((N / D)^2 * M) => N / D * sqrt(M)

src/solver/preprocess.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,10 @@ function clean_f(filtered_expr, var, subs)
4444

4545
if oper === (/)
4646
args = arguments(unwrapped_f)
47-
if any(isequal(var, x) for x in get_variables(args[2]))
48-
filtered_expr = expand(args[1] * args[2])
47+
if !all(isequal(var, x) for x in get_variables(args[2]))
48+
filtered_expr = args[1]
4949
push!(assumptions, substitute(args[2], subs, fold=false))
50-
return filtered_expr, assumptions
5150
end
52-
filtered_expr = args[1]
53-
@info "Assuming $(substitute(args[2], subs, fold=false) != 0)"
5451
end
5552
return filtered_expr, assumptions
5653
end

src/solver/ia_rules.jl renamed to src/solver/special_cases.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,47 @@ function cross_multiply(eq)
3737
return cross_multiply(eq)
3838
end
3939
end
40+
"""
41+
solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
42+
This special case solver expects a single equation in multiple variables and a
43+
variable `s` (this can be any Num, `s` is used for convenience). The function generates
44+
a system of equations to by observing the coefficients of the powers of `s` present in `eq`.
45+
E.g. a system would look like `a+b = 1`, `a-2b = 3` for the eq `(a+b)s + (a-2b)s^2 - (1)s - (3)s^2 = 0`.
46+
After generating this system, it calls `symbolic_solve`, which uses `solve_multivar`. `symbolic_solve` was chosen
47+
instead of `solve_multivar` because it postprocesses the roots in order to simplify them and make them more user friendly.
4048
49+
Generation of system uses cross multiplication in order to simplify the equation and convert it
50+
to a polynomial like shape.
51+
52+
53+
# Arguments
54+
- eq: Single symbolics Num or SymbolicUtils.BasicSymbolic. This is equated to 0 and then solved. E.g. `expr = x+2`, we solve `x+2 = 0`
55+
56+
- s: Variable to "isolate", i.e. ignore and generate the system of equations based on this variable's coefficients.
57+
58+
- dropmultiplicity (optional): Print repeated roots or not?
59+
60+
- warns (optional, this is not used currently): Warn user when something is wrong or not.
61+
62+
# Examples
63+
```jldoctest
64+
julia> @variables a b x s;
65+
66+
julia> eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3;
67+
68+
julia> Symbolics.solve_interms_ofvar(eq, s)
69+
2-element Vector{Any}:
70+
Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 - 1im)*√(5))
71+
Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 + 1im)*√(5))
72+
```
73+
```jldoctest
74+
julia> eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d));
75+
76+
julia> Symbolics.solve_interms_ofvar(eq, s)
77+
1-element Vector{Any}:
78+
Dict{Num, Any}(a => 1, d => 1, b => 1, c => 1)
79+
```
80+
"""
4181
function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
4282
@assert iscall(unwrap(eq))
4383
vars = Symbolics.get_variables(eq)

test/solver.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root
55
@test Base.get_extension(Symbolics, :SymbolicsNemoExt) === nothing
66
@variables x
77
roots = ia_solve(log(2 + x), x)
8-
@test substitute(roots[1], Dict()) == -1.0
98
roots = @test_warn ["Nemo", "required"] ia_solve(log(2 + x^2), x)
109
@test operation(roots[1]) == Symbolics.RootsOf
1110
end
@@ -69,23 +68,28 @@ end
6968
@testset "Solving in terms of a constant var" begin
7069
eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d))
7170
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d])
71+
solve_roots = sort_arr(symbolic_solve(eq, [a,b,c,d]), [a,b,c,d])
7272
known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d])
7373
@test check_approx(calcd_roots, known_roots)
74+
@test check_approx(solve_roots, known_roots)
7475

7576
eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s
7677
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
78+
solve_roots = sort_arr(symbolic_solve(eq, [a,b]), [a,b])
7779
known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b])
7880
@test check_approx(calcd_roots, known_roots)
81+
@test check_approx(solve_roots, known_roots)
7982

8083
eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3
81-
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
84+
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,x])
85+
solve_roots = sort_arr(symbolic_solve(eq, [a,b,x]), [a,b,x])
8286
known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x])
8387
@test check_approx(calcd_roots, known_roots)
88+
@test check_approx(solve_roots, known_roots)
8489
end
8590

8691
@testset "Invalid input" begin
8792
@test_throws AssertionError symbolic_solve(x, x^2)
88-
@test_throws AssertionError symbolic_solve(1/x, x)
8993
end
9094

9195
@testset "Nice univar cases" begin
@@ -355,14 +359,18 @@ end
355359
@testset "Post Process roots" begin
356360
SymbolicUtils.@syms __x
357361
__symsqrt(x) = SymbolicUtils.term(ssqrt, x)
362+
term = SymbolicUtils.term
358363
@test Symbolics.postprocess_root(2 // 1) == 2 && Symbolics.postprocess_root(2 + 0*im) == 2
359364
@test Symbolics.postprocess_root(__symsqrt(4)) == 2
360365
@test isequal(Symbolics.postprocess_root(__symsqrt(__x)^2), __x)
361366

362-
@test Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 0) ) == 1
363-
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.e, 0) ) == 1
364-
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi
365-
@test isequal(Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 1) ), __x)
367+
368+
@test isequal(Symbolics.postprocess_root(term(^, 0, __x)), 0)
369+
@test_broken isequal(Symbolics.postprocess_root(term(/, __x, 0)), Inf)
370+
@test Symbolics.postprocess_root(term(^, __x, 0) ) == 1
371+
@test Symbolics.postprocess_root(term(^, Base.MathConstants.e, 0) ) == 1
372+
@test Symbolics.postprocess_root(term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi
373+
@test isequal(Symbolics.postprocess_root(term(^, __x, 1) ), __x)
366374

367375
x = Symbolics.term(sqrt, 2)
368376
@test isequal(Symbolics.postprocess_root( expand((x + 1)^4) ), 17 + 12x)
@@ -426,7 +434,10 @@ end
426434
lhs = ia_solve(a*x^b + c, x)[1]
427435
lhs2 = symbolic_solve(a*x^b + c, x)[1]
428436
rhs = Symbolics.term(^, -c.val/a.val, 1/b.val)
429-
#@test isequal(lhs, rhs)
437+
@test_broken isequal(lhs, rhs)
438+
439+
@test isequal(symbolic_solve(2/x, x)[1], Inf)
440+
@test isequal(symbolic_solve(x^1.5, x)[1], 0)
430441

431442
lhs = symbolic_solve(log(a*x)-b,x)[1]
432443
@test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)

0 commit comments

Comments
 (0)