Skip to content

Commit 4d01575

Browse files
Merge pull request #75 from frankschae/ensemble_work_precision
Work Precision Set for EnsembleProblem
2 parents 2da3610 + 5a0cfb6 commit 4d01575

File tree

4 files changed

+220
-4
lines changed

4 files changed

+220
-4
lines changed

.travis.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@ julia:
1111
# - julia: nightly
1212
notifications:
1313
email: false
14-
# uncomment the following lines to override the default test script
15-
#script:
16-
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
17-
# - julia -e 'Pkg.clone(pwd()); Pkg.build("DiffEqDevTools"); Pkg.test("DiffEqDevTools"; coverage=true)'
14+
script: julia -e 'using Pkg; Pkg.build(); Pkg.test(coverage=false)'
1815
after_success:
1916
# push coverage results to Coveralls
2017
- julia -e 'cd(Pkg.dir("DiffEqDevTools")); Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'

src/benchmark.jl

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,122 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
394394
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,numruns_error)
395395
end
396396

397+
398+
function WorkPrecisionSet(prob::AbstractEnsembleProblem,abstols,reltols,setups,test_dt=nothing;
399+
numruns=5, trajectories=1000,
400+
print_names=false,names=nothing,appxsol_setup=nothing,expected_value=nothing,
401+
error_estimate=:weak_final,ensemblealg=EnsembleThreads(),
402+
kwargs...)
403+
404+
@assert names === nothing || length(setups) == length(names)
405+
406+
weak_timeseries_errors = error_estimate WEAK_TIMESERIES_ERRORS
407+
weak_dense_errors = error_estimate WEAK_DENSE_ERRORS
408+
409+
N = length(setups); M = length(abstols)
410+
times = Array{Float64}(undef,M,N)
411+
solutions = Array{Any}(undef,M,N)
412+
if names === nothing
413+
names = [string(nameof(typeof(setup[:alg]))) for setup in setups]
414+
end
415+
time_tmp = Vector{Float64}(undef,numruns)
416+
417+
# First calculate all of the errors
418+
for k in 1:N
419+
for j in 1:M
420+
if !haskey(setups[1],:dts)
421+
sol = solve(prob,setups[k][:alg],ensemblealg;
422+
setups[k]...,
423+
abstol=abstols[j],
424+
reltol=reltols[j],
425+
timeseries_errors=false,
426+
dense_errors = false,
427+
trajectories=Int(trajectories),kwargs...)
428+
else
429+
sol = solve(prob,setups[k][:alg],ensemblealg;
430+
setups[k]...,
431+
abstol=abstols[j],
432+
reltol=reltols[j],
433+
dt=setups[k][:dts][j],
434+
timeseries_errors=false,
435+
dense_errors = false,
436+
trajectories=Int(trajectories),kwargs...)
437+
end
438+
solutions[j,k] = sol
439+
end
440+
@info "$(setups[k][:alg]) ($k/$N)"
441+
end
442+
443+
if error_estimate WEAK_ERRORS
444+
if expected_value != nothing
445+
errors = [[LinearAlgebra.norm(Statistics.mean(solutions[i,j].u .- expected_value))
446+
for i in 1:M] for j in 1:N]
447+
else
448+
sol = solve(prob,appxsol_setup[:alg],ensemblealg;kwargs...,appxsol_setup...,
449+
timeseries_errors=false,dense_errors = false,trajectories=Int(trajectories))
450+
errors = [[LinearAlgebra.norm(Statistics.mean(solutions[i,j].u .- sol.u))
451+
for i in 1:M] for j in 1:N]
452+
end
453+
else
454+
error("use RODEProblem instead of EnsembleProblem for strong errors.")
455+
end
456+
457+
local _sol
458+
459+
# Now time it
460+
for k in 1:N
461+
# precompile
462+
GC.gc()
463+
if !haskey(setups[1],:dts)
464+
_sol = solve(prob,setups[k][:alg],ensemblealg;
465+
setups[k]...,
466+
abstol=abstols[1],
467+
reltol=reltols[1],
468+
timeseries_errors=false,
469+
dense_errors = false,
470+
trajectories=Int(trajectories),kwargs...)
471+
else
472+
_sol = solve(prob,setups[k][:alg],ensemblealg;
473+
setups[k]...,
474+
abstol=abstols[1],
475+
reltol=reltols[1],
476+
dt=setups[k][:dts][1],
477+
timeseries_errors=false,
478+
dense_errors = false,
479+
trajectories=Int(trajectories),kwargs...)
480+
end
481+
#x = isempty(_sol.t) ? 0 : round(Int,mean(_sol.t) - sum(_sol.t)/length(_sol.t))
482+
GC.gc()
483+
for j in 1:M
484+
for i in 1:numruns
485+
time_tmp[i] = @elapsed if !haskey(setups[k],:dts)
486+
sol = solve(prob,setups[k][:alg],ensemblealg;
487+
setups[k]...,
488+
abstol=abstols[j],
489+
reltol=reltols[j],
490+
timeseries_errors=false,
491+
dense_errors = false,
492+
trajectories=Int(trajectories),kwargs...)
493+
else
494+
sol = solve(prob,setups[k][:alg],ensemblealg;
495+
setups[k]...,
496+
abstol=abstols[j],
497+
reltol=reltols[j],
498+
dt=setups[k][:dts][j],
499+
timeseries_errors=false,
500+
dense_errors = false,
501+
trajectories=Int(trajectories),kwargs...)
502+
end
503+
end
504+
times[j,k] = mean(time_tmp) #+ x
505+
GC.gc()
506+
end
507+
end
508+
509+
wps = [WorkPrecision(prob,abstols,reltols,errors[i],times[:,i],names[i],N) for i in 1:N]
510+
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,Int(trajectories))
511+
end
512+
397513
function get_sample_errors(prob::AbstractRODEProblem,setup,test_dt=nothing;
398514
appxsol_setup=nothing,
399515
numruns,error_estimate=:final,

test/analyticless_convergence_tests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,44 @@ sim2 = analyticless_test_convergence(dts,prob,SRIW1(),test_dt,trajectories=100,
4040
@test abs(sim2.𝒪est[:final]-1.5) < 0.3
4141
@show sim2.𝒪est[:final]
4242

43+
44+
# EnsembleProblem
45+
46+
function prob_func(prob, i, repeat)
47+
remake(prob,seed=seeds[i])
48+
end
49+
50+
u₀ = [1.0,1.0]
51+
function f2!(du,u,p,t)
52+
du[1] = -273//512*u[1]
53+
du[2] = -1//160*u[1]-(-785//512+sqrt(2)/8)*u[2]
54+
end
55+
function g2!(du,u,p,t)
56+
du[1,1] = 1//4*u[1]
57+
du[1,2] = 1//16*u[1]
58+
du[2,1] = (1-2*sqrt(2))/4*u[1]
59+
du[2,2] = 1//10*u[1]+1//16*u[2]
60+
end
61+
dts = 1 .//2 .^(3:-1:0)
62+
tspan = (0.0,3.0)
63+
64+
h2(z) = z^2 # but apply it only to u[1]
65+
66+
prob = SDEProblem(f2!,g2!,u₀,tspan,noise_rate_prototype=zeros(2,2))
67+
68+
numtraj = Int(1e5)
69+
seed = 100
70+
Random.seed!(seed)
71+
seeds = rand(UInt, numtraj)
72+
ensemble_prob = EnsembleProblem(prob;
73+
output_func = (sol,i) -> (h2(sol[end][1]),false),
74+
prob_func = prob_func
75+
)
76+
sim = test_convergence(dts,ensemble_prob,DRI1(),save_everystep=false,trajectories=numtraj,save_start=false,adaptive=false,weak_timeseries_errors=false,weak_dense_errors=false,expected_value=exp(-3.0))
77+
78+
@test abs(sim.𝒪est[:weak_final]-2.0) < 0.3
79+
@show sim.𝒪est[:weak_final]
80+
4381
### SDDE
4482

4583
function hayes_modelf(du,u,h,p,t)

test/analyticless_stochastic_wp.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,68 @@ se2 = get_sample_errors(prob2,setups[1],test_dt,appxsol_setup = appxsol_setup,
3737
numruns=[5,10,25,50,100],solution_runs=20)
3838

3939
@test all(se[1:5]-se2 .< 1e-1)
40+
41+
42+
# Ensemble Problem with non-commutative noise process
43+
44+
function prob_func(prob, i, repeat)
45+
remake(prob,seed=seeds[i])
46+
end
47+
48+
u₀ = [1.0,1.0]
49+
function f2!(du,u,p,t)
50+
du[1] = -273//512*u[1]
51+
du[2] = -1//160*u[1]-(-785//512+sqrt(2)/8)*u[2]
52+
return nothing
53+
end
54+
function g2!(du,u,p,t)
55+
du[1,1] = 1//4*u[1]
56+
du[1,2] = 1//16*u[1]
57+
du[2,1] = (1-2*sqrt(2))/4*u[1]
58+
du[2,2] = 1//10*u[1]+1//16*u[2]
59+
return nothing
60+
end
61+
dts = 1 .//2 .^(3:-1:0)
62+
tspan = (0.0,3.0)
63+
64+
h2(z) = z^2 # but apply it only to u[1]
65+
66+
prob = SDEProblem(f2!,g2!,u₀,tspan,noise_rate_prototype=zeros(2,2))
67+
68+
numtraj = Int(1e5)
69+
seed = 100
70+
Random.seed!(seed)
71+
seeds = rand(UInt, numtraj)
72+
ensemble_prob = EnsembleProblem(prob;
73+
output_func = (sol,i) -> (h2(sol[end][1]),false),
74+
prob_func = prob_func
75+
)
76+
77+
reltols = 1.0 ./ 4.0 .^ (1:4)
78+
abstols = reltols#[0.0 for i in eachindex(reltols)]
79+
setups = [
80+
Dict(:alg=>EM(),:dts=>dts),
81+
Dict(:alg=>SimplifiedEM(),:dts=>dts),
82+
Dict(:alg=>DRI1(),:dts=>dts, :adaptive=>false)
83+
]
84+
test_dt = 1//1000
85+
appxsol_setup = Dict(:alg=>EM(), :dt=>test_dt)
86+
87+
# without analytical expectation value
88+
wp1 = @time WorkPrecisionSet(ensemble_prob,abstols,reltols,setups,test_dt;
89+
maxiters = 1e7,
90+
verbose=false,save_everystep=false, save_start=false,
91+
appxsol_setup = appxsol_setup,
92+
trajectories=numtraj, error_estimate=:weak_final)
93+
94+
# with analytical expectation value
95+
wp2 = @time WorkPrecisionSet(ensemble_prob,abstols,reltols,setups,test_dt;
96+
maxiters = 1e7,
97+
verbose=false,save_everystep=false, save_start=false,
98+
appxsol_setup = appxsol_setup,expected_value=exp(-3.0),
99+
trajectories=numtraj, error_estimate=:weak_final)
100+
101+
err1 = [wp1.wps[i].errors for i=1:length(setups)]
102+
err2 = [wp2.wps[i].errors for i=1:length(setups)]
103+
104+
@test isapprox(err1, err2, atol=1e-4)

0 commit comments

Comments
 (0)