Skip to content

Commit 739e6f1

Browse files
Merge remote-tracking branch 'origin/master'
2 parents 4e6edbd + ab6a057 commit 739e6f1

File tree

6 files changed

+172
-129
lines changed

6 files changed

+172
-129
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqDevTools"
22
uuid = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.12.0"
4+
version = "2.15.0"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/benchmark.jl

Lines changed: 133 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -36,35 +36,46 @@ function Shootout(prob,setups;appxsol=nothing,names=nothing,error_estimate=:fina
3636
effratios = Matrix{Float64}(undef,N,N)
3737
timeseries_errors = error_estimate TIMESERIES_ERRORS
3838
dense_errors = error_estimate DENSE_ERRORS
39-
if names == nothing
39+
if names === nothing
4040
names = [string(nameof(typeof(setup[:alg]))) for setup in setups]
4141
end
4242
for i in eachindex(setups)
4343
sol = solve(prob,setups[i][:alg];timeseries_errors=timeseries_errors,
44-
dense_errors = dense_errors,kwargs...,setups[i]...) # Compile and get result
44+
dense_errors = dense_errors,kwargs...,setups[i]...)
4545

46-
if appxsol != nothing
47-
errsol = appxtrue(sol,appxsol)
46+
if :prob_choice keys(setups[i])
47+
cur_appxsol = appxsol[setups[i][:prob_choice]]
48+
else
49+
cur_appxsol = appxsol
50+
end
51+
52+
if cur_appxsol != cur_appxsol
53+
errsol = appxtrue(sol,cur_appxsol)
4854
errors[i] = errsol.errors[error_estimate]
4955
solutions[i] = errsol
5056
else
5157
errors[i] = sol.errors[error_estimate]
5258
solutions[i] = sol
5359
end
5460

55-
benchmark_f = let prob=prob,alg=setups[i][:alg],sol=sol,kwargs=kwargs
56-
function benchmark_f()
57-
@elapsed solve(prob,alg,(sol.u),(sol.t),(sol.k);
58-
timeseries_errors = false,
59-
dense_errors = false, kwargs...)
60-
end
61+
if haskey(setups[i], :prob_choice)
62+
_prob = prob[setups[i][:prob_choice]]
63+
else
64+
_prob = prob
65+
end
66+
67+
benchmark_f = let _prob=_prob,alg=setups[i][:alg],sol=sol,kwargs=kwargs
68+
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
69+
timeseries_errors = false,
70+
dense_errors = false, kwargs...)
6171
end
72+
benchmark_f() # pre-compile
6273

63-
b_t = benchmark_f()
74+
b_t = benchmark_f()
6475
if b_t > seconds
6576
times[i] = b_t
6677
else
67-
times[i] = minimum([b_t;map(i->benchmark_f(),2:numruns)])
78+
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
6879
end
6980

7081
effs[i] = 1/(errors[i]*times[i])
@@ -86,10 +97,10 @@ function ShootoutSet(probs,setups;probaux=nothing,
8697
N = length(probs)
8798
shootouts = Vector{Shootout}(undef,N)
8899
winners = Vector{String}(undef,N)
89-
if names == nothing
100+
if names === nothing
90101
names = [string(nameof(typeof(setup[:alg]))) for setup in setups]
91102
end
92-
if probaux == nothing
103+
if probaux === nothing
93104
probaux = Vector{Dict{Symbol,Any}}(undef,N)
94105
for i in 1:N
95106
probaux[i] = Dict{Symbol,Any}()
@@ -143,7 +154,6 @@ mutable struct WorkPrecisionSet
143154
prob
144155
setups
145156
names
146-
sample_error
147157
error_estimate
148158
numruns
149159
end
@@ -153,68 +163,80 @@ function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
153163
N = length(abstols)
154164
errors = Vector{Float64}(undef,N)
155165
times = Vector{Float64}(undef,N)
156-
if name == nothing
166+
if name === nothing
157167
name = "WP-Alg"
158168
end
159-
timeseries_errors = error_estimate TIMESERIES_ERRORS
160-
dense_errors = error_estimate DENSE_ERRORS
161-
for i in 1:N
162-
# Calculate errors and precompile
163-
if dts == nothing
164-
sol = solve(prob,alg;kwargs...,abstol=abstols[i],
165-
reltol=reltols[i],timeseries_errors=timeseries_errors,
166-
dense_errors = dense_errors) # Compile and get result
167-
else
168-
sol = solve(prob,alg;kwargs...,abstol=abstols[i],
169-
reltol=reltols[i],dt=dts[i],timeseries_errors=timeseries_errors,
170-
dense_errors = dense_errors) # Compile and get result
171-
end
172169

173-
if appxsol != nothing
174-
errsol = appxtrue(sol,appxsol)
175-
errors[i] = mean(errsol.errors[error_estimate])
176-
else
177-
errors[i] = mean(sol.errors[error_estimate])
178-
end
170+
if haskey(kwargs, :prob_choice)
171+
_prob = prob[kwargs[:prob_choice]]
172+
else
173+
_prob = prob
174+
end
179175

180-
benchmark_f = let dts=dts,prob=prob,alg=alg,sol=sol,abstols=abstols,reltols=reltols,kwargs=kwargs
181-
function benchmark_f()
182-
if dts == nothing
183-
@elapsed solve(prob,alg,(sol.u),(sol.t),(sol.k);
184-
abstol=(abstols[i]),
185-
reltol=(reltols[i]),
186-
timeseries_errors = false,
187-
dense_errors = false, kwargs...)
176+
let _prob = _prob
177+
timeseries_errors = error_estimate TIMESERIES_ERRORS
178+
dense_errors = error_estimate DENSE_ERRORS
179+
for i in 1:N
180+
if dts === nothing
181+
sol = solve(_prob,alg;kwargs...,abstol=abstols[i],
182+
reltol=reltols[i],timeseries_errors=timeseries_errors,
183+
dense_errors = dense_errors)
184+
else
185+
sol = solve(_prob,alg;kwargs...,abstol=abstols[i],
186+
reltol=reltols[i],dt=dts[i],timeseries_errors=timeseries_errors,
187+
dense_errors = dense_errors)
188+
end
189+
190+
if haskey(kwargs, :prob_choice)
191+
cur_appxsol = appxsol[kwargs[:prob_choice]]
192+
else
193+
cur_appxsol = appxsol
194+
end
195+
196+
if cur_appxsol !== nothing
197+
errsol = appxtrue(sol,cur_appxsol)
198+
errors[i] = mean(errsol.errors[error_estimate])
199+
else
200+
errors[i] = mean(sol.errors[error_estimate])
201+
end
202+
203+
benchmark_f = let dts=dts,_prob=_prob,alg=alg,sol=sol,abstols=abstols,reltols=reltols,kwargs=kwargs
204+
if dts === nothing
205+
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
206+
abstol = abstols[i],
207+
reltol = reltols[i],
208+
timeseries_errors = false,
209+
dense_errors = false, kwargs...)
188210
else
189-
@elapsed solve(prob,alg,(sol.u),(sol.t),(sol.k);
190-
abstol=(abstols[i]),
191-
reltol=(reltols[i]),
192-
dt=(dts[i]),
193-
timeseries_errors = false,
194-
dense_errors = false, kwargs...)
211+
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
212+
abstol = abstols[i],
213+
reltol = reltols[i],
214+
dt = dts[i],
215+
timeseries_errors = false,
216+
dense_errors = false, kwargs...)
195217
end
196218
end
197-
end
219+
benchmark_f() # pre-compile
198220

199-
b_t = benchmark_f()
200-
if b_t > seconds
201-
times[i] = b_t
202-
else
203-
times[i] = minimum([b_t;map(i->benchmark_f(),2:numruns)])
221+
b_t = benchmark_f()
222+
if b_t > seconds
223+
times[i] = b_t
224+
else
225+
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
226+
end
204227
end
205228
end
206229
return WorkPrecision(prob,abstols,reltols,errors,times,name,N)
207230
end
208231

209-
function WorkPrecisionSet(prob::Union{AbstractODEProblem,AbstractDDEProblem,
210-
AbstractDAEProblem},
232+
function WorkPrecisionSet(prob,
211233
abstols,reltols,setups;
212234
print_names=false,names=nothing,appxsol=nothing,
213235
error_estimate=:final,
214236
test_dt=nothing,kwargs...)
215237
N = length(setups)
216238
wps = Vector{WorkPrecision}(undef,N)
217-
if names == nothing
239+
if names === nothing
218240
names = [string(nameof(typeof(setup[:alg]))) for setup in setups]
219241
end
220242
for i in 1:N
@@ -231,7 +253,7 @@ function WorkPrecisionSet(prob::Union{AbstractODEProblem,AbstractDDEProblem,
231253
name=names[i],kwargs...,setups[i]...)
232254
end
233255
end
234-
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,nothing,error_estimate,nothing)
256+
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,nothing)
235257
end
236258

237259
@def error_calculation begin
@@ -285,7 +307,8 @@ end
285307
function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_dt=nothing;
286308
numruns=20,numruns_error = 20,
287309
print_names=false,names=nothing,appxsol_setup=nothing,
288-
error_estimate=:final,parallel_type = :none,kwargs...)
310+
error_estimate=:final,parallel_type = :none,
311+
kwargs...)
289312

290313
timeseries_errors = DiffEqBase.has_analytic(prob.f) && error_estimate TIMESERIES_ERRORS
291314
weak_timeseries_errors = error_estimate WEAK_TIMESERIES_ERRORS
@@ -294,7 +317,7 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
294317
N = length(setups); M = length(abstols)
295318
times = Array{Float64}(undef,M,N)
296319
tmp_solutions = Array{Any}(undef,numruns_error,M,N)
297-
if names == nothing
320+
if names === nothing
298321
names = [string(nameof(typeof(setup[:alg]))) for setup in setups]
299322
end
300323
time_tmp = Vector{Float64}(undef,numruns)
@@ -310,8 +333,7 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
310333
@error_calculation
311334
end
312335
end
313-
analytical_solution_ends = [tmp_solutions[i,1,1].u_analytic[end] for i in 1:numruns_error]
314-
sample_error = 1.96std(norm.(analytical_solution_ends))/sqrt(numruns_error)
336+
315337
_solutions_k = [[EnsembleSolution(tmp_solutions[:,j,k],0.0,true) for j in 1:M] for k in 1:N]
316338
solutions = [[DiffEqBase.calculate_ensemble_errors(sim;weak_timeseries_errors=weak_timeseries_errors,weak_dense_errors=weak_dense_errors) for sim in sol_k] for sol_k in _solutions_k]
317339
if error_estimate WEAK_ERRORS
@@ -365,46 +387,62 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
365387
end
366388

367389
wps = [WorkPrecision(prob,abstols,reltols,errors[i],times[:,i],names[i],N) for i in 1:N]
368-
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,sample_error,error_estimate,numruns_error)
390+
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,numruns_error)
369391
end
370392

371-
@def sample_errors begin
372-
if !DiffEqBase.has_analytic(prob.f)
373-
true_sol = solve(prob,appxsol_setup[:alg];kwargs...,appxsol_setup...,
374-
save_everystep=false)
375-
analytical_solution_ends[i] = norm(true_sol.u[end])
376-
else
377-
_dt = prob.tspan[2] - prob.tspan[1]
378-
if typeof(prob.u0) <: Number
379-
W = sqrt(_dt)*randn()
380-
else
381-
W = sqrt(_dt)*randn(size(prob.u0))
393+
function get_sample_errors(prob::AbstractRODEProblem,setup,test_dt=nothing;
394+
appxsol_setup=nothing,
395+
numruns,error_estimate=:final,
396+
sample_error_runs = Int(1e7),
397+
solution_runs,
398+
parallel_type = :none,kwargs...)
399+
400+
maxnumruns = findmax(numruns)[1]
401+
402+
tmp_solutions_full = map(1:solution_runs) do i
403+
@info "Solution Run: $i"
404+
# Use the WorkPrecision stuff to calculate the errors
405+
tmp_solutions = Array{Any}(undef,maxnumruns,1,1)
406+
setups = [setup]
407+
abstols = [1e-2] # Standard default
408+
reltols = [1e-2] # Standard default
409+
M = 1; N = 1
410+
timeseries_errors = false; dense_errors = false
411+
if parallel_type == :threads
412+
Threads.@threads for i in 1:maxnumruns
413+
@error_calculation
414+
end
415+
elseif parallel_type == :none
416+
for i in 1:maxnumruns
417+
@error_calculation
418+
end
382419
end
383-
analytical_solution_ends[i] = norm(prob.f.analytic(prob.u0,prob.p,prob.tspan[2],W))
420+
tmp_solutions = vec(tmp_solutions)
384421
end
385-
end
386422

387-
function get_sample_errors(prob::AbstractRODEProblem,test_dt=nothing;
388-
appxsol_setup=nothing,
389-
numruns=20,std_estimation_runs = maximum(numruns),
390-
error_estimate=:final,parallel_type = :none,kwargs...)
391-
_std_estimation_runs = Int(std_estimation_runs)
392-
analytical_solution_ends = Vector{typeof(norm(prob.u0))}(undef,_std_estimation_runs)
393-
if parallel_type == :threads
394-
Threads.@threads for i in 1:_std_estimation_runs
395-
@sample_errors
396-
end
397-
elseif parallel_type == :none
398-
for i in 1:_std_estimation_runs
399-
@info "Standard deviation estimation: $i/$_std_estimation_runs"
400-
@sample_errors
423+
if DiffEqBase.has_analytic(prob.f)
424+
analytical_mean_end = mean(1:sample_error_runs) do i
425+
_dt = prob.tspan[2] - prob.tspan[1]
426+
if typeof(prob.u0) <: Number
427+
W = sqrt(_dt)*randn()
428+
else
429+
W = sqrt(_dt)*randn(size(prob.u0))
430+
end
431+
prob.f.analytic(prob.u0,prob.p,prob.tspan[2],W)
401432
end
433+
else
434+
# Use the mean of the means as the analytical mean
435+
analytical_mean_end = mean(mean(tmp_solutions[i].u[end] for i in 1:length(tmp_solutions)) for tmp_solutions in tmp_solutions_full)
402436
end
403-
est_std = std(analytical_solution_ends)
404-
if typeof(numruns) <: Number
405-
return 1.96est_std/sqrt(numruns)
437+
438+
if numruns isa Number
439+
mean_solution_ends = [mean([tmp_solutions[i].u[end] for i in 1:maxnumruns]) for tmp_solutions in tmp_solutions_full]
440+
return sample_error = 1.96std(norm(mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/sqrt(numruns)
406441
else
407-
return [1.96est_std/sqrt(_numruns) for _numruns in numruns]
442+
map(1:length(numruns)) do i
443+
mean_solution_ends = [mean([tmp_solutions[i].u[end] for i in 1:numruns[i]]) for tmp_solutions in tmp_solutions_full]
444+
sample_error = 1.96std(norm(mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/sqrt(numruns[i])
445+
end
408446
end
409447
end
410448

src/convergence.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function ConvergenceSimulation(solutions,convergence_axis;
3636
end
3737

3838
function test_convergence(dts::AbstractArray,prob::Union{AbstractRODEProblem,AbstractSDEProblem},
39-
alg;trajectories=10000,save_everystep=true,timeseries_steps=1,
39+
alg;trajectories,save_everystep=true,timeseries_steps=1,
4040
timeseries_errors=save_everystep,adaptive=false,
4141
weak_timeseries_errors=false,weak_dense_errors=false,kwargs...)
4242
N = length(dts)
@@ -66,8 +66,13 @@ function analyticless_test_convergence(dts::AbstractArray,
6666
for j in 1:trajectories
6767
@info "Monte Carlo iteration: $j/$trajectories"
6868
t = prob.tspan[1]:test_dt:prob.tspan[2]
69-
brownian_values = cumsum([[zeros(size(prob.u0))];[sqrt(test_dt)*randn(size(prob.u0)) for i in 1:length(t)-1]])
70-
brownian_values2 = cumsum([[zeros(size(prob.u0))];[sqrt(test_dt)*randn(size(prob.u0)) for i in 1:length(t)-1]])
69+
if prob.noise_rate_prototype === nothing
70+
brownian_values = cumsum([[zeros(size(prob.u0))];[sqrt(test_dt)*randn(size(prob.u0)) for i in 1:length(t)-1]])
71+
brownian_values2 = cumsum([[zeros(size(prob.u0))];[sqrt(test_dt)*randn(size(prob.u0)) for i in 1:length(t)-1]])
72+
else
73+
brownian_values = cumsum([[zeros(size(prob.noise_rate_prototype,2))];[sqrt(test_dt)*randn(size(prob.noise_rate_prototype,2)) for i in 1:length(t)-1]])
74+
brownian_values2 = cumsum([[zeros(size(prob.noise_rate_prototype,2))];[sqrt(test_dt)*randn(size(prob.noise_rate_prototype,2)) for i in 1:length(t)-1]])
75+
end
7176
np = NoiseGrid(t,brownian_values,brownian_values2)
7277
_prob = SDEProblem(prob.f,prob.g,prob.u0,prob.tspan,
7378
noise=np,

0 commit comments

Comments
 (0)