Skip to content

Commit 0841a5e

Browse files
ensemble error depwarn fixes
1 parent c30f1de commit 0841a5e

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.52.1"
4+
version = "2.52.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/ensemble/ensemble_solutions.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
8989
errors = Dict{Symbol, Vector{eltype(u[1].u[1])}}() #Should add type information
9090
error_means = Dict{Symbol, eltype(u[1].u[1])}()
9191
error_medians = Dict{Symbol, eltype(u[1].u[1])}()
92+
93+
analyticvoa = u[1].u_analytic isa AbstractVectorOfArray ? true : false
94+
9295
for k in keys(u[1].errors)
9396
errors[k] = [sol.errors[k] for sol in u]
9497
error_means[k] = mean(errors[k])
@@ -98,12 +101,24 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
98101
weak_errors = Dict{Symbol, eltype(u[1].u[1])}()
99102
# Final
100103
m_final = mean([s.u[end] for s in u])
101-
m_final_analytic = mean([s.u_analytic[end] for s in u])
104+
105+
if analyticvoa
106+
m_final_analytic = mean([s.u_analytic.u[end] for s in u])
107+
else
108+
m_final_analytic = mean([s.u_analytic[end] for s in u])
109+
end
110+
102111
res = norm(m_final - m_final_analytic)
103112
weak_errors[:weak_final] = res
104113
if weak_timeseries_errors
105-
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)])
106-
for i in 1:length(u[1])]
114+
115+
if analyticvoa
116+
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic.u[i] for j in 1:length(u)])
117+
for i in 1:length(u[1])]
118+
else
119+
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)])
120+
for i in 1:length(u[1])]
121+
end
107122
ts_l2_errors = [sqrt.(sum(abs2, err) / length(err)) for err in ts_weak_errors]
108123
l2_tmp = sqrt(sum(abs2, ts_l2_errors) / length(ts_l2_errors))
109124
max_tmp = maximum([maximum(abs.(err)) for err in ts_weak_errors])
@@ -113,8 +128,9 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
113128
if weak_dense_errors
114129
densetimes = collect(range(u[1].t[1], stop = u[1].t[end], length = 100))
115130
u_analytic = [[sol.prob.f.analytic(sol.prob.u0, sol.prob.p, densetimes[i],
116-
sol.W(densetimes[i])[1])
117-
for i in eachindex(densetimes)] for sol in u]
131+
sol.W(densetimes[i])[1])
132+
for i in eachindex(densetimes)] for sol in u]
133+
118134
udense = [u[j](densetimes) for j in 1:length(u)]
119135
dense_weak_errors = [mean([udense[j].u[i] - u_analytic[j][i] for j in 1:length(u)])
120136
for i in eachindex(densetimes)]

0 commit comments

Comments
 (0)