Skip to content

Commit db554a9

Browse files
new expand_basis added
1 parent 1408768 commit db554a9

File tree

4 files changed

+81
-41
lines changed

4 files changed

+81
-41
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicNumericIntegration"
22
uuid = "78aadeae-fbc0-11eb-17b6-c7ec0477ba9e"
33
authors = ["Shahriar Iravanian <[email protected]>"]
4-
version = "1.2.0"
4+
version = "1.2.1"
55

66
[deps]
77
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"

src/integral.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 1
8686
s, u, ε = integrate_sum(eq, x; bypass, abstol, num_trials, num_steps,
8787
radius, show_basis, opt, symbolic,
8888
max_basis, verbose, complex_plane, use_optim)
89+
90+
s = beautify(s)
8991

9092
if detailed
9193
return s, u, ε

src/sparse.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function solve_sparse(eq, x, basis, radius; kwargs...)
99
l = find_independent_subset(A; abstol)
1010
A, basis = A[l, l], basis[l]
1111

12-
y₁, ε₁ = sparse_fit(A, basis, opt; abstol)
12+
y₁, ε₁ = sparse_fit(A, basis; abstol, opt)
1313
if ε₁ < abstol
1414
return y₁, ε₁, basis
1515
end
@@ -30,7 +30,7 @@ function solve_sparse(eq, x, basis, radius; kwargs...)
3030

3131
# moving toward the poles
3232
modify_basis_matrix!(A, X, eq, x, basis, radius; abstol)
33-
y₄, ε₄ = sparse_fit(A, basis, opt; abstol)
33+
y₄, ε₄ = sparse_fit(A, basis; abstol, opt)
3434

3535
if ε₄ < abstol || ε₄ < ε₁
3636
return y₄, ε₄
@@ -49,6 +49,8 @@ end
4949

5050
function init_basis_matrix(eq, x, basis, radius, complex_plane; abstol = 1e-6)
5151
n = length(basis)
52+
eq = value(eq)
53+
basis = value.(basis)
5254

5355
# A is an nxn matrix holding the values of the fragments at n random points
5456
A = zeros(Complex{Float64}, (n, n))
@@ -81,7 +83,7 @@ function init_basis_matrix(eq, x, basis, radius, complex_plane; abstol = 1e-6)
8183
end
8284
end
8385
catch e
84-
println("Error from init_basis_matrix!: ", e)
86+
println("Error from init_basis_matrix!: ", e)
8587
end
8688
l -= 1
8789
end
@@ -113,7 +115,7 @@ function DataDrivenSparse.active_set!(idx::BitMatrix, p::SoftThreshold,
113115
DataDrivenSparse.active_set!(idx, p, abs.(x), λ)
114116
end
115117

116-
function sparse_fit(A, basis, opt; abstol = 1e-6)
118+
function sparse_fit(A, basis; abstol = 1e-6, opt = STLSQ(exp.(-10:1:0)))
117119
n, m = size(A)
118120

119121
try

src/symbolic.jl

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function equiv(y, x)
3131
if is_add(y)
3232
return sum(equiv(u, x) for u in args(y))
3333
elseif is_mul(y)
34-
return prod(isdependent(u, x) ? u : 1 for u in args(y))
34+
return prod(isdependent(u, x) ? u : 1 for u in args(y))
3535
elseif is_div(y)
3636
return expand_fraction(y, x)
3737
elseif is_number(y)
@@ -132,21 +132,22 @@ end
132132
Returns:
133133
a dict of sym : value pairs
134134
"""
135-
function subs_symbols(eq, x; include_x = false, radius = 5.0)
135+
function subs_symbols(eq, x; include_x = false, radius = 5.0, as_complex=true)
136136
S = Dict()
137137
for v in get_variables(value(eq))
138138
if !isequal(v, x)
139-
S[v] = randn()
139+
S[v] = as_complex ? Complex(randn()) : randn()
140140
end
141141
end
142142

143143
if include_x
144-
S[x] = randn()
144+
S[x] = as_complex ? Complex(randn()) : randn()
145145
end
146146

147147
return S
148148
end
149149

150+
150151
"""
151152
Splits terms into a part dependent on x and a part constant w.r.t. x
152153
For example, `atomize(a*sin(b*x), x)` is `(a, sin(b*x))`)
@@ -209,8 +210,8 @@ end
209210
function make_eqs(eq, x, basis)
210211
frags = Dict()
211212

212-
for d in terms(eq)
213-
c, a = atomize(d, x)
213+
for term in terms(eq)
214+
c, a = atomize(term, x)
214215
frags[a] = c
215216
end
216217

@@ -219,8 +220,8 @@ function make_eqs(eq, x, basis)
219220
for (i, b) in enumerate(basis)
220221
db = expand_fraction(diff(b, x), x)
221222

222-
for d in terms(db)
223-
c, a = atomize(d, x)
223+
for term in terms(db)
224+
c, a = atomize(term, x)
224225
frags[a] = get(frags, a, 0) + c * θ[i]
225226
end
226227

@@ -277,7 +278,14 @@ function make_square(eq, x, vars, frags)
277278
for i in 1:n
278279
S = subs_symbols(eq, x; include_x = true)
279280
q = sum(substitute(k, S) * v for (k, v) in frags)
280-
push!(eqs, q ~ 0)
281+
Q = (q ~ 0)
282+
if Q isa Array
283+
# Q returns a complex array
284+
# a different pathway is needed here!
285+
return nothing
286+
else
287+
push!(eqs, Q)
288+
end
281289
end
282290

283291
return eqs
@@ -303,29 +311,6 @@ function apply_coefs(q, ker, x)
303311
return beautify(c) * a
304312
end
305313

306-
function solve_eqs(eq, x, ker, eqs, vars; abstol = 1e-6, radius = 1.0, verbose = false)
307-
try
308-
A, b = make_Ab(eqs, vars)
309-
q = A \ b
310-
q = value.(q)
311-
sol = apply_coefs(q, ker, x)
312-
313-
# test if sol solves ∫ eq dx
314-
S = subs_symbols(eq, x; include_x = true, radius)
315-
err = substitute(diff(sol, x) - eq, S)
316-
317-
if abs(err) < abstol
318-
return sol
319-
end
320-
catch e
321-
if verbose
322-
@warn(e)
323-
end
324-
end
325-
326-
return nothing
327-
end
328-
329314
"""
330315
The main entry point for symbolic integration.
331316
@@ -336,31 +321,49 @@ end
336321
Returns:
337322
the integral or nothing if no solution
338323
"""
339-
function integrate_symbolic(eq, x; abstol = 1e-6, radius = 1.0, verbose = false)
324+
function integrate_symbolic(eq, x; abstol = 1e-6, radius = 1.0, verbose = false, num_steps=2)
340325
eq = expand(eq)
326+
coef, eq = atomize(eq, x)
341327

342328
if is_holonomic(eq, x)
343329
basis = blender(eq, x)
344330
else
345331
basis = generate_homotopy(eq, x)
346332
end
333+
334+
for k = 1:num_steps
335+
sol = try_symbolic(eq, x, basis; abstol, radius, verbose)
336+
337+
if sol != nothing
338+
return coef * sol
339+
end
340+
341+
if k < num_steps
342+
basis = expand_basis_symbolic(basis, x)
343+
end
344+
end
345+
346+
return nothing
347+
end
347348

349+
350+
function try_symbolic(eq, x, basis; abstol = 1e-6, radius = 1.0, verbose = false)
348351
ker = best_hints(eq, x, basis)
349352

350353
if ker == nothing
351354
return nothing
352355
end
353356

354357
ker = [atomize(y, x)[2] for y in ker]
355-
356358
eqs, vars, frags = make_eqs(eq, x, ker)
357-
358359
sol = solve_eqs(eq, x, ker, eqs, vars; abstol, radius, verbose)
359360

360361
if sol == nothing
361362
try
362363
eqs = make_square(eq, x, vars, frags)
363-
sol = solve_eqs(eq, x, ker, eqs, vars; abstol, radius, verbose)
364+
if eqs != nothing
365+
sol = solve_eqs(eq, x, ker, eqs, vars; abstol, radius, verbose=false)
366+
end
364367
catch e
365368
if verbose
366369
@warn(e)
@@ -370,3 +373,36 @@ function integrate_symbolic(eq, x; abstol = 1e-6, radius = 1.0, verbose = false)
370373

371374
return sol
372375
end
376+
377+
378+
function solve_eqs(eq, x, ker, eqs, vars; abstol = 1e-6, radius = 1.0, verbose = false)
379+
try
380+
A, b = make_Ab(eqs, vars)
381+
q = A \ b
382+
q = value.(q)
383+
sol = apply_coefs(q, ker, x)
384+
385+
# test if sol solves ∫ eq dx
386+
S = subs_symbols(eq, x; include_x = true, radius)
387+
err = substitute(diff(sol, x) - eq, S)
388+
389+
if abs(err) < abstol
390+
return sol
391+
end
392+
catch e
393+
if verbose
394+
@warn(e)
395+
end
396+
end
397+
398+
return nothing
399+
end
400+
401+
402+
function expand_basis_symbolic(basis, x)
403+
b = sum(basis)
404+
basis = split_terms(expand((1+x)*(b + diff(b, x))), x)
405+
406+
return basis
407+
end
408+

0 commit comments

Comments
 (0)