Skip to content

Commit d30b5d5

Browse files
NumericalPlan added
1 parent db554a9 commit d30b5d5

File tree

6 files changed

+268
-177
lines changed

6 files changed

+268
-177
lines changed

src/SymbolicNumericIntegration.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,21 @@ module SymbolicNumericIntegration
33
using SymbolicUtils
44
using SymbolicUtils: istree, operation, arguments
55
using Symbolics
6-
using Symbolics: value, get_variables, expand_derivatives
6+
using Symbolics: value, get_variables, expand_derivatives, coeff
77
using SymbolicUtils.Rewriters
88
using SymbolicUtils: exprtype, BasicSymbolic
99

1010
using DataDrivenDiffEq, DataDrivenSparse
1111

12+
struct NumericalPlan
13+
abstol::Float64
14+
radius::Float64
15+
complex_plane::Bool
16+
opt::DataDrivenDiffEq.AbstractDataDrivenAlgorithm
17+
end
18+
19+
default_plan() = NumericalPlan(1e-6, 5.0, true, STLSQ(exp.(-10:1:0)))
20+
1221
include("utils.jl")
1322
include("tree.jl")
1423
include("special.jl")

src/homotopy.jl

Lines changed: 72 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,14 @@ function generate_homotopy(eq, x)
9494

9595
for i in 1:length(ks)
9696
μ = u[i]
97-
h₁, ∂h₁ = apply_partial_int_rules(sub[μ], x)
98-
h₁ = substitute(h₁, sub)
97+
y, dy = apply_partial_int_rules(sub[μ], x)
98+
99+
y = substitute(y, sub)
100+
∂y = guard_zero(diff(dy, x))
99101

100102
for j in 1:ks[i]
101-
h = substitute((q / μ^j) /h₁, sub)
102-
S += expand((ω + h₁) *+ h))
103+
h = substitute((q / μ^j) /y, sub)
104+
S += expand((ω + y) *+ h))
103105
end
104106
end
105107

@@ -109,94 +111,86 @@ end
109111

110112
########################## Main Integration Rules ##################################
111113

112-
@syms 𝛷(x)
114+
@syms 𝛷(x, u)
113115

114116
partial_int_rules = [
115117
# trigonometric functions
116-
@rule 𝛷(sin(~x)) => (cos(~x) + si(~x), ~x)
117-
@rule 𝛷(cos(~x)) => (sin(~x) + ci(~x), ~x)
118-
@rule 𝛷(tan(~x)) => (log(cos(~x)), ~x)
119-
@rule 𝛷(csc(~x)) => (log(csc(~x) + cot(~x)) + log(sin(~x)), ~x)
120-
@rule 𝛷(sec(~x)) => (log(sec(~x) + tan(~x)) + log(cos(~x)), ~x)
121-
@rule 𝛷(cot(~x)) => (log(sin(~x)), ~x)
118+
@rule 𝛷(~x, sin(~u)) => (cos(~u) + si(~u), ~u)
119+
@rule 𝛷(~x, cos(~u)) => (sin(~u) + ci(~u), ~u)
120+
@rule 𝛷(~x, tan(~u)) => (log(cos(~u)), ~u)
121+
@rule 𝛷(~x, csc(~u)) => (log(csc(~u) + cot(~u)) + log(sin(~u)), ~u)
122+
@rule 𝛷(~x, sec(~u)) => (log(sec(~u) + tan(~u)) + log(cos(~u)), ~u)
123+
@rule 𝛷(~x, cot(~u)) => (log(sin(~u)), ~u)
122124
# hyperbolic functions
123-
@rule 𝛷(sinh(~x)) => (cosh(~x), ~x)
124-
@rule 𝛷(cosh(~x)) => (sinh(~x), ~x)
125-
@rule 𝛷(tanh(~x)) => (log(cosh(~x)), ~x)
126-
@rule 𝛷(csch(~x)) => (log(tanh(~x / 2)), ~x)
127-
@rule 𝛷(sech(~x)) => (atan(sinh(~x)), ~x)
128-
@rule 𝛷(coth(~x)) => (log(sinh(~x)), ~x)
125+
@rule 𝛷(~x, sinh(~u)) => (cosh(~u), ~u)
126+
@rule 𝛷(~x, cosh(~u)) => (sinh(~u), ~u)
127+
@rule 𝛷(~x, tanh(~u)) => (log(cosh(~u)), ~u)
128+
@rule 𝛷(~x, csch(~u)) => (log(tanh(~u / 2)), ~u)
129+
@rule 𝛷(~x, sech(~u)) => (atan(sinh(~u)), ~u)
130+
@rule 𝛷(~x, coth(~u)) => (log(sinh(~u)), ~u)
129131
# 1/trigonometric functions
130-
@rule 𝛷(1 / sin(~x)) => (log(csc(~x) + cot(~x)) + log(sin(~x)), ~x)
131-
@rule 𝛷(1 / cos(~x)) => (log(sec(~x) + tan(~x)) + log(cos(~x)), ~x)
132-
@rule 𝛷(1 / tan(~x)) => (log(sin(~x)) + log(tan(~x)), ~x)
133-
@rule 𝛷(1 / csc(~x)) => (cos(~x) + log(csc(~x)), ~x)
134-
@rule 𝛷(1 / sec(~x)) => (sin(~x) + log(sec(~x)), ~x)
135-
@rule 𝛷(1 / cot(~x)) => (log(cos(~x)) + log(cot(~x)), ~x)
132+
@rule 𝛷(~x, 1 / sin(~u)) => (log(csc(~u) + cot(~u)) + log(sin(~u)), ~u)
133+
@rule 𝛷(~x, 1 / cos(~u)) => (log(sec(~u) + tan(~u)) + log(cos(~u)), ~u)
134+
@rule 𝛷(~x, 1 / tan(~u)) => (log(sin(~u)) + log(tan(~u)), ~u)
135+
@rule 𝛷(~x, 1 / csc(~u)) => (cos(~u) + log(csc(~u)), ~u)
136+
@rule 𝛷(~x, 1 / sec(~u)) => (sin(~u) + log(sec(~u)), ~u)
137+
@rule 𝛷(~x, 1 / cot(~u)) => (log(cos(~u)) + log(cot(~u)), ~u)
136138
# 1/hyperbolic functions
137-
@rule 𝛷(1 / sinh(~x)) => (log(tanh(~x / 2)) + log(sinh(~x)), ~x)
138-
@rule 𝛷(1 / cosh(~x)) => (atan(sinh(~x)) + log(cosh(~x)), ~x)
139-
@rule 𝛷(1 / tanh(~x)) => (log(sinh(~x)) + log(tanh(~x)), ~x)
140-
@rule 𝛷(1 / csch(~x)) => (cosh(~x) + log(csch(~x)), ~x)
141-
@rule 𝛷(1 / sech(~x)) => (sinh(~x) + log(sech(~x)), ~x)
142-
@rule 𝛷(1 / coth(~x)) => (log(cosh(~x)) + log(coth(~x)), ~x)
139+
@rule 𝛷(~x, 1 / sinh(~u)) => (log(tanh(~u / 2)) + log(sinh(~u)), ~u)
140+
@rule 𝛷(~x, 1 / cosh(~u)) => (atan(sinh(~u)) + log(cosh(~u)), ~u)
141+
@rule 𝛷(~x, 1 / tanh(~u)) => (log(sinh(~u)) + log(tanh(~u)), ~u)
142+
@rule 𝛷(~x, 1 / csch(~u)) => (cosh(~u) + log(csch(~u)), ~u)
143+
@rule 𝛷(~x, 1 / sech(~u)) => (sinh(~u) + log(sech(~u)), ~u)
144+
@rule 𝛷(~x, 1 / coth(~u)) => (log(cosh(~u)) + log(coth(~u)), ~u)
143145
# inverse trigonometric functions
144-
@rule 𝛷(asin(~x)) => (~x * asin(~x) + sqrt(1 - ~x * ~x), ~x)
145-
@rule 𝛷(acos(~x)) => (~x * acos(~x) + sqrt(1 - ~x * ~x), ~x)
146-
@rule 𝛷(atan(~x)) => (~x * atan(~x) + log(~x * ~x + 1), ~x)
147-
@rule 𝛷(acsc(~x)) => (~x * acsc(~x) + atanh(1 - ^(~x, -2)), ~x)
148-
@rule 𝛷(asec(~x)) => (~x * asec(~x) + acosh(~x), ~x)
149-
@rule 𝛷(acot(~x)) => (~x * acot(~x) + log(~x * ~x + 1), ~x)
146+
@rule 𝛷(~x, asin(~u)) => (~u * asin(~u) + sqrt(1 - ~u * ~u), ~u)
147+
@rule 𝛷(~x, acos(~u)) => (~u * acos(~u) + sqrt(1 - ~u * ~u), ~u)
148+
@rule 𝛷(~x, atan(~u)) => (~u * atan(~u) + log(~u * ~u + 1), ~u)
149+
@rule 𝛷(~x, acsc(~u)) => (~u * acsc(~u) + atanh(1 - ^(~u, -2)), ~u)
150+
@rule 𝛷(~x, asec(~u)) => (~u * asec(~u) + acosh(~u), ~u)
151+
@rule 𝛷(~x, acot(~u)) => (~u * acot(~u) + log(~u * ~u + 1), ~u)
150152
# inverse hyperbolic functions
151-
@rule 𝛷(asinh(~x)) => (~x * asinh(~x) + sqrt(~x * ~x + 1), ~x)
152-
@rule 𝛷(acosh(~x)) => (~x * acosh(~x) + sqrt(~x * ~x - 1), ~x)
153-
@rule 𝛷(atanh(~x)) => (~x * atanh(~x) + log(~x + 1), ~x)
154-
@rule 𝛷(acsch(~x)) => (acsch(~x), ~x)
155-
@rule 𝛷(asech(~x)) => (asech(~x), ~x)
156-
@rule 𝛷(acoth(~x)) => (~x * acot(~x) + log(~x + 1), ~x)
153+
@rule 𝛷(~x, asinh(~u)) => (~u * asinh(~u) + sqrt(~u * ~u + 1), ~u)
154+
@rule 𝛷(~x, acosh(~u)) => (~u * acosh(~u) + sqrt(~u * ~u - 1), ~u)
155+
@rule 𝛷(~x, atanh(~u)) => (~u * atanh(~u) + log(~u + 1), ~u)
156+
@rule 𝛷(~x, acsch(~u)) => (acsch(~u), ~u)
157+
@rule 𝛷(~x, asech(~u)) => (asech(~u), ~u)
158+
@rule 𝛷(~x, acoth(~u)) => (~u * acot(~u) + log(~u + 1), ~u)
157159
# logarithmic and exponential functions
158-
@rule 𝛷(log(~x)) => (~x + ~x * log(~x) + sum(pow_minus_rule(~x, -1); init = one(~x)),
159-
~x);
160-
@rule 𝛷(1 / log(~x)) => (log(log(~x)) + li(~x), ~x)
161-
@rule 𝛷(exp(~x)) => (exp(~x) + ei(~x) + erfi_rule(~x), ~x)
162-
@rule 𝛷(^(exp(~x), ~k::is_neg)) => (^(exp(-~x), -~k), ~x)
160+
@rule 𝛷(~x, log(~u)) => (~u + ~u * log(~u) + sum(pow_minus_rule(~u, ~x, -1); init = one(~u)),
161+
~u);
162+
@rule 𝛷(~x, 1 / log(~u)) => (log(log(~u)) + li(~u), ~u)
163+
@rule 𝛷(~x, exp(~u)) => (exp(~u) + ei(~u) + erfi_(~x), ~u)
164+
@rule 𝛷(~x, ^(exp(~u), ~k::is_neg)) => (^(exp(-~u), -~k), ~u)
163165
# square-root functions
164-
@rule 𝛷(^(~x, ~k::is_abs_half)) => (sum(sqrt_rule(~x, ~k); init = one(~x)), ~x);
165-
@rule 𝛷(sqrt(~x)) => (sum(sqrt_rule(~x, 0.5); init = one(~x)), ~x);
166-
@rule 𝛷(1 / sqrt(~x)) => (sum(sqrt_rule(~x, -0.5); init = one(~x)), ~x);
166+
@rule 𝛷(~x, ^(~u, ~k::is_abs_half)) => (sum(sqrt_rule(~u, ~x, ~k); init = one(~u)), ~u);
167+
@rule 𝛷(~x, sqrt(~u)) => (sum(sqrt_rule(~u, ~x, 0.5); init = one(~u)), ~u);
168+
@rule 𝛷(~x, 1 / sqrt(~u)) => (sum(sqrt_rule(~u, ~x, -0.5); init = one(~u)), ~u);
167169
# rational functions
168-
@rule 𝛷(1 / ^(~x::is_univar_poly, ~k::is_pos_int)) => (sum(pow_minus_rule(~x, -~k);
169-
init = one(~x)),
170-
~x);
171-
@rule 𝛷(1 / ~x::is_univar_poly) => (sum(pow_minus_rule(~x, -1); init = one(~x)), ~x);
172-
@rule 𝛷(^(~x, -1)) => (log(~x), ~x)
173-
@rule 𝛷(^(~x, ~k::is_neg_int)) => (sum(^(~x, i) for i in (~k + 1):-1), ~x)
174-
@rule 𝛷(1 / ~x) => (log(~x), ~x)
175-
@rule 𝛷(^(~x, ~k::is_pos_int)) => (sum(^(~x, i + 1) for i in 1:(~k + 1)), ~x)
176-
@rule 𝛷(1) => (𝑥, 1)
177-
@rule 𝛷(~x) => ((~x + ^(~x, 2)), ~x)]
170+
@rule 𝛷(~x, 1 / ^(~u::is_univar_poly, ~k::is_pos_int)) => (sum(pow_minus_rule(~u, ~x, -~k);
171+
init = one(~u)),
172+
~u);
173+
@rule 𝛷(~x, 1 / ~u::is_univar_poly) => (sum(pow_minus_rule(~u, ~x, -1); init = one(~u)), ~u);
174+
@rule 𝛷(~x, ^(~u, -1)) => (log(~u) + ~u * log(~u), ~u)
175+
@rule 𝛷(~x, ^(~u, ~k::is_neg_int)) => (sum(^(~u, i) for i in (~k + 1):-1), ~u)
176+
@rule 𝛷(~x, 1 / ~u) => (log(~u), ~u)
177+
@rule 𝛷(~x, ^(~u, ~k::is_pos_int)) => (sum(^(~u, i + 1) for i in 1:(~k + 1)), ~u)
178+
@rule 𝛷(~x, 1) => (𝑥, 1)
179+
@rule 𝛷(~x, ~u) => ((~u + ^(~u, 2)), ~u)]
178180

179181
function apply_partial_int_rules(eq, x)
180-
y, dy = Chain(partial_int_rules)(𝛷(value(eq)))
181-
return y, guard_zero(diff(dy, x))
182+
y, dy = Chain(partial_int_rules)(𝛷(x, value(eq)))
183+
return y, dy
182184
end
183185

184186
################################################################
185187

186-
function erfi_rule(eq)
187-
if is_univar_poly(eq)
188-
x = var(eq)
189-
return erfi_(x)
190-
end
191-
return 0
192-
end
193-
194-
function pow_minus_rule(p, k; abstol = 1e-8)
188+
function pow_minus_rule(p, x, k; abstol = 1e-8)
195189
if !is_univar_poly(p)
196-
return [p^k, p^(k + 1), log(p)]
190+
return [p^k, p^(k + 1), log(p), p*log(p)]
197191
end
198192

199-
x = var(p)
193+
# x = var(p)
200194
d = poly_deg(p)
201195

202196
for j in 1:10 # will try 10 times to find the roots
@@ -229,14 +223,17 @@ function pow_minus_rule(p, k; abstol = 1e-8)
229223
end
230224
end
231225

232-
function sqrt_rule(p, k)
226+
function sqrt_rule(p, x, k)
233227
h = Any[p^k, p^(k + 1)]
228+
229+
Δ = diff(p, x)
230+
push!(h, log/2 + sqrt(p)))
234231

235232
if !is_univar_poly(p)
236233
return h
237234
end
238235

239-
x = var(p)
236+
# x = var(p)
240237

241238
if poly_deg(p) == 2
242239
r, s = find_roots(p, x)
@@ -255,7 +252,6 @@ function sqrt_rule(p, k)
255252
end
256253
end
257254

258-
Δ = expand_derivatives(Differential(x)(p))
259-
push!(h, log(0.5 * Δ + sqrt(p)))
260255
return h
261256
end
257+

src/integral.jl

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,22 @@ Output:
5353
- `unsolved`: the residual unsolved portion of the input
5454
- `err`: the numerical error in reaching the solution
5555
"""
56-
function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 10,
56+
function integrate(eq, x = nothing;
57+
abstol = 1e-6,
58+
num_steps = 2,
59+
num_trials = 10,
5760
radius = 1.0,
58-
show_basis = false, opt = STLSQ(exp.(-10:1:0)), bypass = false,
59-
symbolic = false, max_basis = 100, verbose = false, complex_plane = true,
60-
homotopy = true, use_optim = false, detailed = true)
61+
show_basis = false,
62+
opt = STLSQ(exp.(-10:1:0)),
63+
bypass = false,
64+
symbolic = false,
65+
max_basis = 100,
66+
verbose = false,
67+
complex_plane = true,
68+
homotopy = true,
69+
use_optim = false,
70+
detailed = true)
71+
6172
eq = expand(eq)
6273

6374
if x == nothing
@@ -82,11 +93,12 @@ function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 1
8293
return x * eq
8394
end
8495
end
96+
97+
plan = NumericalPlan(abstol, radius, complex_plane, opt)
98+
99+
s, u, ε = integrate_sum(eq, x; plan, bypass, num_trials, num_steps,
100+
show_basis, symbolic, max_basis, verbose, use_optim)
85101

86-
s, u, ε = integrate_sum(eq, x; bypass, abstol, num_trials, num_steps,
87-
radius, show_basis, opt, symbolic,
88-
max_basis, verbose, complex_plane, use_optim)
89-
90102
s = beautify(s)
91103

92104
if detailed
@@ -182,11 +194,11 @@ The output is the same as `integrate`
182194
"""
183195
function integrate_term(eq, x; kwargs...)
184196
args = Dict(kwargs)
185-
abstol, num_steps, num_trials, show_basis, symbolic, verbose, max_basis,
186-
radius = args[:abstol], args[:num_steps],
187-
args[:num_trials], args[:show_basis], args[:symbolic],
188-
args[:verbose],
189-
args[:max_basis], args[:radius]
197+
plan, num_steps, num_trials, show_basis, symbolic, verbose, max_basis =
198+
args[:plan], args[:num_steps], args[:num_trials], args[:show_basis],
199+
args[:symbolic], args[:verbose], args[:max_basis]
200+
201+
abstol = plan.abstol
190202

191203
if is_number(eq)
192204
y = eq * x
@@ -202,7 +214,7 @@ function integrate_term(eq, x; kwargs...)
202214
end
203215

204216
if symbolic
205-
y = integrate_symbolic(eq, x; abstol, radius)
217+
y = integrate_symbolic(eq, x; plan)
206218
if y == nothing
207219
if has_sym_consts
208220
@info("Symbolic integration failed. Try changing constant parameters ([$(join(params, ", "))]) to numerical values.")
@@ -246,11 +258,11 @@ function integrate_term(eq, x; kwargs...)
246258

247259
for j in 1:num_trials
248260
basis = isodd(j) ? basis1 : basis2
249-
r = radius
250-
y, ε = try_integrate(eq, x, basis, r; kwargs...)
261+
y, ε = try_integrate(eq, x, basis; plan)
262+
263+
ε = accept_solution(eq, x, y; plan)
251264

252-
ε = accept_solution(eq, x, y, r)
253-
if ε < abstol
265+
if ε < abstol
254266
return y, 0, ε
255267
elseif ε < εᵣ
256268
εᵣ = ε
@@ -259,8 +271,8 @@ function integrate_term(eq, x; kwargs...)
259271
end
260272

261273
if i < num_steps
262-
basis1, ok1 = expand_basis(prune_basis(eq, x, basis1, radius; kwargs...), x)
263-
basis2, ok2 = expand_basis(prune_basis(eq, x, basis2, radius; kwargs...), x)
274+
basis1, ok1 = expand_basis(prune_basis(eq, x, basis1; plan), x)
275+
basis2, ok2 = expand_basis(prune_basis(eq, x, basis2; plan), x)
264276

265277
if !ok1 && ~ok2
266278
break
@@ -278,7 +290,7 @@ end
278290
###############################################################################
279291

280292
"""
281-
try_integrate(eq, x, basis, radius; kwargs...)
293+
try_integrate(eq, x, basis; plan)
282294
283295
is the main dispatch point to call different sparse solvers. It tries to
284296
find a linear combination of the basis, whose derivative is equal to eq
@@ -288,19 +300,13 @@ output:
288300
- solved: the solved integration problem or 0 otherwise
289301
- err: the numerical error in reaching the solution
290302
"""
291-
function try_integrate(eq, x, basis, radius; kwargs...)
292-
args = Dict(kwargs)
293-
use_optim = args[:use_optim]
294-
303+
function try_integrate(eq, x, basis; plan = default_plan())
295304
if isempty(basis)
296305
return 0, Inf
297306
end
298307

299-
if use_optim
300-
return solve_optim(eq, x, basis, radius; kwargs...)
301-
else
302-
return solve_sparse(eq, x, basis, radius; kwargs...)
303-
end
308+
# return solve_optim(eq, x, basis; plan)
309+
return solve_sparse(eq, x, basis; plan)
304310
end
305311

306312
#################################################################################
@@ -310,10 +316,11 @@ end
310316
311317
is used for debugging and should not be called in the course of normal execution
312318
"""
313-
function integrate_basis(eq, x = var(eq); abstol = 1e-6, radius = 1.0, complex_plane = true)
319+
function integrate_basis(eq, x = var(eq); plan = default_plan())
314320
eq = cache(expand(eq))
315321
basis = generate_basis(eq, x, false)
316322
n = length(basis)
317-
A, X = init_basis_matrix(eq, x, basis, radius, complex_plane; abstol)
323+
A, X = init_basis_matrix(eq, x, basis; plan)
318324
return basis, A, X
319325
end
326+

src/numeric_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ end
1919

2020
accept_solution(eq::ExprCache, x, sol, radius) = accept_solution(expr(eq), x, sol, radius)
2121

22-
function accept_solution(eq, x, sol, radius)
22+
function accept_solution(eq, x, sol; plan = default_plan())
2323
try
24-
x₀ = test_point(true, radius)
25-
Δ = substitute(expand_derivatives(Differential(x)(sol) - eq), Dict(x => x₀))
24+
x₀ = test_point(plan.complex_plane, plan.radius)
25+
Δ = substitute(diff(sol, x) - expr(eq), Dict(x => x₀))
2626
return abs(Δ)
2727
catch e
2828
#

0 commit comments

Comments
 (0)