Skip to content

Commit 290ecee

Browse files
Merge pull request #74 from frankschae/test_convergence_ensemble
test_convergence() for ensemble simulations
2 parents b3a326f + 93a8db9 commit 290ecee

File tree

2 files changed

+53
-21
lines changed

2 files changed

+53
-21
lines changed

src/DiffEqDevTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import Base: length
1010

1111
import DiffEqBase: AbstractODEProblem, AbstractDDEProblem,
1212
AbstractODESolution, AbstractRODEProblem, AbstractSDEProblem,
13-
AbstractSDDEProblem,
13+
AbstractSDDEProblem, AbstractEnsembleProblem,
1414
AbstractDAEProblem, @def, ConvergenceSetup, DEAlgorithm,
1515
ODERKTableau, AbstractTimeseriesSolution, ExplicitRKTableau,
1616
ImplicitRKTableau

src/convergence.jl

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,24 @@ mutable struct ConvergenceSimulation{SolType}
88
end
99

1010
function ConvergenceSimulation(solutions,convergence_axis;
11-
auxdata=nothing,additional_errors=nothing)
11+
auxdata=nothing,additional_errors=nothing, expected_value=nothing)
1212
N = size(solutions,1)
1313
uEltype = eltype(solutions[1].u[1])
1414
errors = Dict() #Should add type information
15-
if isempty(solutions[1].errors)
16-
error("Errors dictionary is empty. No analytical solution set.")
17-
end
18-
for k in keys(solutions[1].errors)
19-
errors[k] = [mean(sol.errors[k]) for sol in solutions]
15+
if expected_value == nothing
16+
if isempty(solutions[1].errors)
17+
error("Errors dictionary is empty. No analytical solution set.")
18+
end
19+
for k in keys(solutions[1].errors)
20+
errors[k] = [mean(sol.errors[k]) for sol in solutions]
21+
end
2022
end
2123
if additional_errors != nothing
2224
for k in keys(additional_errors)
2325
errors[k] = additional_errors[k]
2426
end
2527
end
28+
2629
𝒪est = Dict((calc𝒪estimates(p) for p = pairs(errors)))
2730
#𝒪est = Dict(map(calc𝒪estimates,errors))
2831
𝒪esttmp = Dict() #Makes Dict of Any to be more compatible
@@ -35,25 +38,53 @@ function ConvergenceSimulation(solutions,convergence_axis;
3538
return(ConvergenceSimulation(solutions,errors,N,auxdata,𝒪est,convergence_axis))
3639
end
3740

38-
function test_convergence(dts::AbstractArray,prob::Union{AbstractRODEProblem,AbstractSDEProblem},
41+
function test_convergence(dts::AbstractArray,
42+
prob::Union{AbstractRODEProblem,AbstractSDEProblem,AbstractEnsembleProblem},
3943
alg,ensemblealg=EnsembleThreads();
40-
trajectories,save_everystep=true,timeseries_steps=1,
44+
trajectories,save_start=true,save_everystep=true,timeseries_steps=1,
4145
timeseries_errors=save_everystep,adaptive=false,
42-
weak_timeseries_errors=false,weak_dense_errors=false,kwargs...)
46+
weak_timeseries_errors=false,weak_dense_errors=false,
47+
expected_value=nothing,kwargs...)
4348
N = length(dts)
44-
ensemble_prob = EnsembleProblem(prob)
45-
_solutions = [solve(ensemble_prob,alg,ensemblealg;dt=dts[i],save_everystep=save_everystep,
46-
timeseries_steps=timeseries_steps,adaptive=adaptive,
47-
timeseries_errors=timeseries_errors,trajectories=trajectories,
48-
kwargs...) for i in 1:N]
49-
solutions = [DiffEqBase.calculate_ensemble_errors(sim;weak_timeseries_errors=weak_timeseries_errors,weak_dense_errors=weak_dense_errors) for sim in _solutions]
49+
50+
if typeof(prob) <: AbstractEnsembleProblem
51+
ensemble_prob = prob
52+
else
53+
ensemble_prob = EnsembleProblem(prob)
54+
end
55+
56+
_solutions = Array{Any}(undef,length(dts))
57+
for i in 1:length(dts)
58+
sol = solve(ensemble_prob,alg,ensemblealg;dt=dts[i],adaptive=adaptive,
59+
save_start=save_start,save_everystep=save_everystep,timeseries_steps=timeseries_steps,
60+
timeseries_errors=timeseries_errors,weak_timeseries_errors=weak_timeseries_errors,
61+
weak_dense_errors=weak_dense_errors,trajectories=Int(trajectories),kwargs...)
62+
@info "dt: $(dts[i]) ($i/$N)"
63+
_solutions[i] = sol
64+
end
65+
5066
auxdata = Dict("dts" => dts)
51-
# Now Calculate Weak Errors
52-
additional_errors = Dict()
53-
for k in keys(solutions[1].weak_errors)
54-
additional_errors[k] = [sol.weak_errors[k] for sol in solutions]
67+
68+
if expected_value == nothing
69+
solutions = [DiffEqBase.calculate_ensemble_errors(sim;weak_timeseries_errors=weak_timeseries_errors,weak_dense_errors=weak_dense_errors) for sim in _solutions]
70+
# Now Calculate Weak Errors
71+
additional_errors = Dict()
72+
for k in keys(solutions[1].weak_errors)
73+
additional_errors[k] = [sol.weak_errors[k] for sol in solutions]
74+
end
75+
76+
else
77+
additional_errors = Dict()
78+
additional_errors[:weak_final] = []
79+
for sol in _solutions
80+
weak_final = LinearAlgebra.norm(Statistics.mean(sol.u .- expected_value))
81+
push!(additional_errors[:weak_final],weak_final)
82+
end
83+
solutions = _solutions
5584
end
56-
ConvergenceSimulation(solutions,dts,auxdata=auxdata,additional_errors=additional_errors)
85+
86+
return ConvergenceSimulation(solutions,dts,auxdata=auxdata,additional_errors=additional_errors,
87+
expected_value=expected_value)
5788
end
5889

5990
function analyticless_test_convergence(dts::AbstractArray,
@@ -153,6 +184,7 @@ end
153184
function calc𝒪estimates(error::Pair)
154185
key = error.first
155186
error =error.second
187+
156188
if ndims(error)>1 error=mean(error,1) end
157189
S = Vector{eltype(error)}(undef, length(error)-1)
158190
for i=1:length(error)-1

0 commit comments

Comments
 (0)