Skip to content

Commit 1546740

Browse files
before refactoring
1 parent d30b5d5 commit 1546740

File tree

6 files changed

+111
-160
lines changed

6 files changed

+111
-160
lines changed

src/integral.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function integrate(eq, x = nothing;
5757
abstol = 1e-6,
5858
num_steps = 2,
5959
num_trials = 10,
60-
radius = 1.0,
60+
radius = 5.0,
6161
show_basis = false,
6262
opt = STLSQ(exp.(-10:1:0)),
6363
bypass = false,
@@ -243,7 +243,6 @@ function integrate_term(eq, x; kwargs...)
243243
return 0, expr(eq), Inf
244244
end
245245

246-
# D = Differential(x)
247246
ε₀ = Inf
248247
y₀ = 0
249248

@@ -259,7 +258,6 @@ function integrate_term(eq, x; kwargs...)
259258
for j in 1:num_trials
260259
basis = isodd(j) ? basis1 : basis2
261260
y, ε = try_integrate(eq, x, basis; plan)
262-
263261
ε = accept_solution(eq, x, y; plan)
264262

265263
if ε < abstol

src/numeric_utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ accept_solution(eq::ExprCache, x, sol, radius) = accept_solution(expr(eq), x, so
2121

2222
function accept_solution(eq, x, sol; plan = default_plan())
2323
try
24-
x₀ = test_point(plan.complex_plane, plan.radius)
25-
Δ = substitute(diff(sol, x) - expr(eq), Dict(x => x₀))
24+
# x₀ = test_point(plan.complex_plane, plan.radius)
25+
# Δ = substitute(diff(sol, x) - expr(eq), Dict(x => x₀))
26+
S = subs_symbols(eq, x; include_x = true, plan.radius)
27+
Δ = substitute(diff(sol, x) - expr(eq), S)
2628
return abs(Δ)
2729
catch e
2830
#

src/rules.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ s_rules = [@rule Ω(+(~~xs)) => sum(map(Ω, ~~xs))
7979
@rule Ω(~x::is_linear_poly) => ~x
8080
@rule Ω(sin(~x::is_linear_poly)) => sin(~x)
8181
@rule Ω(cos(~x::is_linear_poly)) => cos(~x)
82+
@rule Ω(exp(~x + ~y)) => exp(~x) * exp(~y)
8283
@rule Ω(sinh(~x::is_linear_poly)) => sinh(~x)
8384
@rule Ω(cosh(~x::is_linear_poly)) => cosh(~x)
8485
@rule Ω(exp(~x::is_linear_poly)) => exp(~x)
@@ -274,7 +275,13 @@ h_rules = [@rule +(~~xs) => ω + sum(~~xs)
274275
# complexity returns a measure of the complexity of an equation
275276
# it is roughly similar ro kolmogorov complexity
276277
function complexity(eq)
277-
_, eq = ops(eq)
278-
h = Prewalk(PassThrough(Chain(h_rules)))(eq)
279-
return substitute(h, Dict=> 1))
278+
eq = value(eq)
279+
if istree(eq)
280+
return 1 + sum(complexity(t) for t in args(eq))
281+
elseif is_number(eq)
282+
return abs(eq)
283+
else
284+
return 1
285+
end
280286
end
287+

src/sparse.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11

22

3-
function solve_sparse(eq, x, basis; plan = default_plan())
3+
function solve_sparse(eq, x, basis; plan = default_plan(), AX = nothing)
44
abstol = plan.abstol
55

6-
A, X = init_basis_matrix(eq, x, basis; plan)
6+
if AX == nothing
7+
A, X, V = init_basis_matrix(eq, x, basis; plan)
8+
else
9+
A, X, V = AX
10+
end
711

812
# find a linearly independent subset of the basis
913
l = find_independent_subset(A; abstol)
10-
A, basis = A[l, l], basis[l]
14+
A, V, basis = A[:,l], V[:,l], basis[l]
1115

12-
y₁, ε₁ = sparse_fit(A, basis; plan)
16+
y₁, ε₁ = sparse_fit(A, V, basis; plan)
1317
if ε₁ < abstol
14-
return y₁, ε₁, basis
18+
return y₁, ε₁
1519
end
1620

1721
rank = sum(l)
@@ -30,7 +34,7 @@ function solve_sparse(eq, x, basis; plan = default_plan())
3034

3135
# moving toward the poles
3236
modify_basis_matrix!(A, X, eq, x, basis)
33-
y₄, ε₄ = sparse_fit(A, basis; plan)
37+
y₄, ε₄ = sparse_fit(A, V, basis; plan)
3438

3539
if ε₄ < abstol || ε₄ < ε₁
3640
return y₄, ε₄
@@ -45,14 +49,13 @@ function prune_basis(eq, x, basis; plan = default_plan())
4549
return basis[l]
4650
end
4751

48-
function init_basis_matrix(eq, x, basis; plan = default_plan())
52+
function init_basis_matrix(eq, x, basis; plan = default_plan(), nv=1)
4953
n = length(basis)
5054
eq = value(eq)
5155
basis = value.(basis)
5256

53-
# A is an nxn matrix holding the values of the fragments at n random points
54-
A = zeros(Complex{Float64}, (n, n))
55-
X = zeros(Complex{Float64}, n)
57+
A = zeros(Complex{Float64}, (n+nv, n))
58+
X = zeros(Complex{Float64}, n+nv)
5659

5760
S = subs_symbols(eq, x)
5861
if !isempty(S)
@@ -64,9 +67,9 @@ function init_basis_matrix(eq, x, basis; plan = default_plan())
6467
Δbasis_fun = deriv_fun!.(basis, x)
6568

6669
k = 1
67-
l = 10 * n# max attempt
70+
l = 10 * (n+nv) # max attempt
6871

69-
while k <= n && l > 0
72+
while k <= n+nv && l > 0
7073
try
7174
x₀ = test_point(plan.complex_plane, plan.radius)
7275
X[k] = x₀
@@ -86,7 +89,7 @@ function init_basis_matrix(eq, x, basis; plan = default_plan())
8689
l -= 1
8790
end
8891

89-
return A, X
92+
return A[1:n,:], X[1:n], A[n+1:end,:]
9093
end
9194

9295
function modify_basis_matrix!(A, X, eq, x, basis)
@@ -113,7 +116,7 @@ function DataDrivenSparse.active_set!(idx::BitMatrix, p::SoftThreshold,
113116
DataDrivenSparse.active_set!(idx, p, abs.(x), λ)
114117
end
115118

116-
function sparse_fit(A, basis; plan = default_plan())
119+
function sparse_fit(A, V, basis; plan = default_plan())
117120
n, m = size(A)
118121

119122
try
@@ -124,16 +127,18 @@ function sparse_fit(A, basis; plan = default_plan())
124127
res, _... = solver(A', b)
125128
q₀ = DataDrivenSparse.coef(first(res))
126129

127-
ε = rms(A * q₀' .- 1)
130+
ε = rms(V * q₀' .- 1)
128131
q = nice_parameter.(q₀)
132+
129133
if sum(iscomplex.(q)) > 2
130134
return nothing, Inf
131135
end # eliminating complex coefficients
132-
return sum(q[i] * expr(basis[i]) for i in 1:length(basis) if q[i] != 0;
133-
init = 0),
134-
abs(ε)
136+
137+
sol = sum(q[i] * expr(basis[i]) for i in 1:length(basis) if q[i] != 0; init = 0)
138+
return sol, abs(ε)
139+
135140
catch e
136-
println("Error from sparse_fit", e)
141+
println("Error from sparse_fit: ", e)
137142
return nothing, Inf
138143
end
139144
end
@@ -175,10 +180,10 @@ end
175180

176181
function hints(eq, x, basis; plan = default_plan())
177182
abstol = plan.abstol
178-
A, X = init_basis_matrix(eq, x, basis; plan)
183+
A, X, V = init_basis_matrix(eq, x, basis; plan)
179184
# find a linearly independent subset of the basis
180185
l = find_independent_subset(A; abstol)
181-
A, basis = A[l, l], basis[l]
186+
A, V, basis = A[:,l], V[:,l], basis[l]
182187

183188
n, m = size(A)
184189

@@ -189,14 +194,17 @@ function hints(eq, x, basis; plan = default_plan())
189194
maxiters = 1000))
190195
res, _... = solver(A', b)
191196
q = DataDrivenSparse.coef(first(res))
192-
err = abs(rms(A * q' .- 1))
193-
if err < abstol
197+
198+
ε = abs(rms(V * q' .- 1))
199+
200+
if ε < abstol
194201
sel = abs.(q) .> abstol
195-
h = [basis[i] for i in 1:length(basis) if sel[i]]
202+
h = [basis[i] for i in 1:length(basis) if sel[i]]
196203
else
197204
h = []
198205
end
199-
return h, err
206+
207+
return h, ε
200208
catch e
201209
# println("Error from hints: ", e)
202210
end
@@ -211,8 +219,10 @@ function best_hints(eq, x, basis; plan = default_plan(), num_trials = 10)
211219
for _ in 1:num_trials
212220
try
213221
h, err = hints(eq, x, basis; plan)
214-
push!(H, h)
215-
push!(L, err < plan.abstol ? length(h) : length(basis))
222+
if err < plan.abstol
223+
push!(H, h)
224+
push!(L, length(h))
225+
end
216226
catch e
217227
#
218228
end

0 commit comments

Comments
 (0)