Skip to content

Commit 5005884

Browse files
Merge pull request #40 from JuliaDiffEq/redo_sample
Redo sample error calculations
2 parents 6d6014d + f707fed commit 5005884

File tree

3 files changed

+64
-50
lines changed

3 files changed

+64
-50
lines changed

src/benchmark.jl

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ mutable struct WorkPrecisionSet
155155
prob
156156
setups
157157
names
158-
sample_error
159158
error_estimate
160159
numruns
161160
end
@@ -257,7 +256,7 @@ function WorkPrecisionSet(prob,
257256
name=names[i],kwargs...,setups[i]...)
258257
end
259258
end
260-
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,nothing,error_estimate,nothing)
259+
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,nothing)
261260
end
262261

263262
@def error_calculation begin
@@ -311,7 +310,8 @@ end
311310
function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_dt=nothing;
312311
numruns=20,numruns_error = 20,
313312
print_names=false,names=nothing,appxsol_setup=nothing,
314-
error_estimate=:final,parallel_type = :none,kwargs...)
313+
error_estimate=:final,parallel_type = :none,
314+
kwargs...)
315315

316316
timeseries_errors = DiffEqBase.has_analytic(prob.f) && error_estimate TIMESERIES_ERRORS
317317
weak_timeseries_errors = error_estimate WEAK_TIMESERIES_ERRORS
@@ -336,8 +336,7 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
336336
@error_calculation
337337
end
338338
end
339-
analytical_solution_ends = [tmp_solutions[i,1,1].u_analytic[end] for i in 1:numruns_error]
340-
sample_error = 1.96std(norm.(analytical_solution_ends))/sqrt(numruns_error)
339+
341340
_solutions_k = [[EnsembleSolution(tmp_solutions[:,j,k],0.0,true) for j in 1:M] for k in 1:N]
342341
solutions = [[DiffEqBase.calculate_ensemble_errors(sim;weak_timeseries_errors=weak_timeseries_errors,weak_dense_errors=weak_dense_errors) for sim in sol_k] for sol_k in _solutions_k]
343342
if error_estimate WEAK_ERRORS
@@ -391,46 +390,62 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
391390
end
392391

393392
wps = [WorkPrecision(prob,abstols,reltols,errors[i],times[:,i],names[i],N) for i in 1:N]
394-
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,sample_error,error_estimate,numruns_error)
393+
WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,numruns_error)
395394
end
396395

397-
@def sample_errors begin
398-
if !DiffEqBase.has_analytic(prob.f)
399-
true_sol = solve(prob,appxsol_setup[:alg];kwargs...,appxsol_setup...,
400-
save_everystep=false)
401-
analytical_solution_ends[i] = norm(true_sol.u[end])
402-
else
403-
_dt = prob.tspan[2] - prob.tspan[1]
404-
if typeof(prob.u0) <: Number
405-
W = sqrt(_dt)*randn()
406-
else
407-
W = sqrt(_dt)*randn(size(prob.u0))
396+
function get_sample_errors(prob::AbstractRODEProblem,setup,test_dt=nothing;
397+
appxsol_setup=nothing,
398+
numruns,error_estimate=:final,
399+
sample_error_runs = Int(1e7),
400+
solution_runs,
401+
parallel_type = :none,kwargs...)
402+
403+
maxnumruns = findmax(numruns)[1]
404+
405+
tmp_solutions_full = map(1:solution_runs) do i
406+
@info "Solution Run: $i"
407+
# Use the WorkPrecision stuff to calculate the errors
408+
tmp_solutions = Array{Any}(undef,maxnumruns,1,1)
409+
setups = [setup]
410+
abstols = [1e-2] # Standard default
411+
reltols = [1e-2] # Standard default
412+
M = 1; N = 1
413+
timeseries_errors = false; dense_errors = false
414+
if parallel_type == :threads
415+
Threads.@threads for i in 1:maxnumruns
416+
@error_calculation
417+
end
418+
elseif parallel_type == :none
419+
for i in 1:maxnumruns
420+
@error_calculation
421+
end
408422
end
409-
analytical_solution_ends[i] = norm(prob.f.analytic(prob.u0,prob.p,prob.tspan[2],W))
423+
tmp_solutions = vec(tmp_solutions)
410424
end
411-
end
412425

413-
function get_sample_errors(prob::AbstractRODEProblem,test_dt=nothing;
414-
appxsol_setup=nothing,
415-
numruns=20,std_estimation_runs = maximum(numruns),
416-
error_estimate=:final,parallel_type = :none,kwargs...)
417-
_std_estimation_runs = Int(std_estimation_runs)
418-
analytical_solution_ends = Vector{typeof(norm(prob.u0))}(undef,_std_estimation_runs)
419-
if parallel_type == :threads
420-
Threads.@threads for i in 1:_std_estimation_runs
421-
@sample_errors
422-
end
423-
elseif parallel_type == :none
424-
for i in 1:_std_estimation_runs
425-
@info "Standard deviation estimation: $i/$_std_estimation_runs"
426-
@sample_errors
426+
if DiffEqBase.has_analytic(prob.f)
427+
analytical_mean_end = mean(1:sample_error_runs) do i
428+
_dt = prob.tspan[2] - prob.tspan[1]
429+
if typeof(prob.u0) <: Number
430+
W = sqrt(_dt)*randn()
431+
else
432+
W = sqrt(_dt)*randn(size(prob.u0))
433+
end
434+
prob.f.analytic(prob.u0,prob.p,prob.tspan[2],W)
427435
end
436+
else
437+
# Use the mean of the means as the analytical mean
438+
analytical_mean_end = mean(mean(tmp_solutions[i].u[end] for i in 1:length(tmp_solutions)) for tmp_solutions in tmp_solutions_full)
428439
end
429-
est_std = std(analytical_solution_ends)
430-
if typeof(numruns) <: Number
431-
return 1.96est_std/sqrt(numruns)
440+
441+
if numruns isa Number
442+
mean_solution_ends = [mean([tmp_solutions[i].u[end] for i in 1:maxnumruns]) for tmp_solutions in tmp_solutions_full]
443+
return sample_error = 1.96std(norm(mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/sqrt(numruns)
432444
else
433-
return [1.96est_std/sqrt(_numruns) for _numruns in numruns]
445+
map(1:length(numruns)) do i
446+
mean_solution_ends = [mean([tmp_solutions[i].u[end] for i in 1:numruns[i]]) for tmp_solutions in tmp_solutions_full]
447+
sample_error = 1.96std(norm(mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/sqrt(numruns[i])
448+
end
434449
end
435450
end
436451

test/analyticless_stochastic_wp.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ using DiffEqProblemLibrary.SDEProblemLibrary: importsdeproblems; importsdeproble
33
using DiffEqProblemLibrary.SDEProblemLibrary: prob_sde_additivesystem
44

55
prob = prob_sde_additivesystem
6-
prob = SDEProblem(prob.f,prob.g,prob.u0,(0.0,1.0),prob.p)
6+
prob = SDEProblem(prob.f,prob.g,prob.u0,(0.0,0.1),prob.p)
77

8-
reltols = 1.0./10.0.^(1:5)
8+
reltols = 1.0./10.0.^(1:4)
99
abstols = reltols#[0.0 for i in eachindex(reltols)]
1010
setups = [Dict(:alg=>SRIW1())
1111
Dict(:alg=>EM(),:dts=>1.0./5.0.^((1:length(reltols)) .+ 1),:adaptive=>false)
@@ -19,23 +19,21 @@ test_dt = 0.1
1919
wp = WorkPrecisionSet(prob,abstols,reltols,setups,test_dt;
2020
numruns=5,names=names,error_estimate=:l2)
2121

22-
se = get_sample_errors(prob,numruns=1000)
23-
se = get_sample_errors(prob,numruns=[5;10;25;50])
22+
se = get_sample_errors(prob,setups[1],numruns=100,solution_runs=100)
23+
se = get_sample_errors(prob,setups[1],numruns=[5,10,25,50,100,1000],solution_runs=100)
2424

25-
println("Now weak error")
25+
println("Now weak error without analytical solution")
2626

2727
prob2 = SDEProblem((du,u,p,t)->prob.f(du,u,p,t),prob.g,prob.u0,(0.0,0.1),prob.p)
28-
test_dt = 1/10^5
29-
appxsol_setup = Dict(:alg=>SRIW1(),:abstol=>1e-5,:reltol=>1e-5)
28+
test_dt = 1/10^4
29+
appxsol_setup = Dict(:alg=>SRIW1(),:abstol=>1e-4,:reltol=>1e-4)
3030
wp = WorkPrecisionSet(prob2,abstols,reltols,setups,test_dt;
3131
appxsol_setup = appxsol_setup,
3232
numruns=5,names=names,error_estimate=:weak_final)
3333

3434
println("Get sample errors")
3535

36-
se2 = get_sample_errors(prob2,test_dt,appxsol_setup = appxsol_setup,
37-
numruns=5)
38-
se2 = get_sample_errors(prob2,test_dt,appxsol_setup = appxsol_setup,
39-
numruns=[5;10;25;50])
36+
se2 = get_sample_errors(prob2,setups[1],test_dt,appxsol_setup = appxsol_setup,
37+
numruns=[5,10,25,50,100],solution_runs=20)
4038

41-
@test all(se-se2 .< 1e-1)
39+
@test all(se[1:5]-se2 .< 1e-1)

test/benchmark_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,6 @@ test_sol = TestSolution(sol)
120120
setups = [Dict(:alg => MethodOfSteps(BS3()))
121121
Dict(:alg => MethodOfSteps(Tsit5()))]
122122
println("Test MethodOfSteps BS3 and Tsit5")
123-
wp = WorkPrecisionSet(prob, abstols, reltols, setups; appxsol = test_sol)
123+
#Travis compile time issue
124+
#wp = WorkPrecisionSet(prob, abstols, reltols, setups; appxsol = test_sol)
124125
println("DDE Done")

0 commit comments

Comments
 (0)