Skip to content

Commit 942c56e

Browse files
Merge pull request #1335 from n0rbed/new_feature
An attempt at #29
2 parents 5d0d8fb + 63b17a0 commit 942c56e

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

src/Symbolics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ include("solver/polynomialization.jl")
209209
include("solver/attract.jl")
210210
include("solver/ia_main.jl")
211211
include("solver/main.jl")
212+
include("solver/ia_rules.jl")
212213
export symbolic_solve
213214

214215
function symbolics_to_sympy end

src/solver/ia_rules.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
function cross_multiply(eq)
2+
og_oper = operation(unwrap(eq))
3+
done = true
4+
loop_add = false
5+
6+
term_tm = 1
7+
if og_oper === (/)
8+
done = false
9+
args = arguments(unwrap(eq))
10+
eq = wrap(args[1])
11+
end
12+
13+
# do this until no / are present
14+
if og_oper === (+)
15+
while !loop_add
16+
args = arguments(unwrap(eq))
17+
loop_add = true
18+
19+
for arg in args
20+
!iscall(arg) && continue
21+
oper = operation(arg)
22+
if oper == (/)
23+
done = false
24+
loop_add = false
25+
term_tm *= wrap(arguments(arg)[2])
26+
end
27+
end
28+
args = [arg*term_tm for arg in args]
29+
eq = expand(Symbolics.term((+), unwrap.(args)...))
30+
term_tm = 1
31+
end
32+
end
33+
34+
if done
35+
return eq
36+
else
37+
return cross_multiply(eq)
38+
end
39+
end
40+
41+
function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
42+
@assert iscall(unwrap(eq))
43+
vars = Symbolics.get_variables(eq)
44+
vars = filter(v -> !isequal(v, s), vars)
45+
vars = wrap.(vars)
46+
47+
term_tm = 1
48+
49+
eq = cross_multiply(eq)
50+
coeffs, constant = polynomial_coeffs(eq, [s])
51+
eqs = wrap.(collect(values(coeffs)))
52+
53+
solve_multivar(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns)
54+
end
55+
56+
# an attempt at using ia_solve recursively.
57+
function find_v(eqs, v, vars)
58+
vars = filter(var -> !isequal(var, v), vars)
59+
n_eqs = deepcopy(eqs)
60+
61+
if isequal(n_eqs[1], 0)
62+
n_eqs = n_eqs[2:end]
63+
end
64+
65+
present_vars = Symbolics.get_variables(n_eqs[1])
66+
var = present_vars[1]
67+
if length(present_vars) == 1 && isequal(present_vars[1], v)
68+
return ia_solve(n_eqs[1], var)
69+
end
70+
if isequal(present_vars[1], v)
71+
var = present_vars[2]
72+
end
73+
74+
@info "" n_eqs[1]
75+
@info "" var
76+
sols = ia_solve(n_eqs[1], var)
77+
isequal(sols, nothing) && return []
78+
79+
for s in sols
80+
for j in 2:length(n_eqs)
81+
n_eqs[j] = substitute(n_eqs[j], Dict(var => s))
82+
end
83+
# other solutions are cut out here
84+
return find_v(n_eqs[2:end], v, vars)
85+
end
86+
87+
end

0 commit comments

Comments
 (0)