@@ -36,35 +36,46 @@ function Shootout(prob,setups;appxsol=nothing,names=nothing,error_estimate=:fina
36
36
effratios = Matrix {Float64} (undef,N,N)
37
37
timeseries_errors = error_estimate ∈ TIMESERIES_ERRORS
38
38
dense_errors = error_estimate ∈ DENSE_ERRORS
39
- if names == nothing
39
+ if names === nothing
40
40
names = [string (nameof (typeof (setup[:alg ]))) for setup in setups]
41
41
end
42
42
for i in eachindex (setups)
43
43
sol = solve (prob,setups[i][:alg ];timeseries_errors= timeseries_errors,
44
- dense_errors = dense_errors,kwargs... ,setups[i]. .. ) # Compile and get result
44
+ dense_errors = dense_errors,kwargs... ,setups[i]. .. )
45
45
46
- if appxsol != nothing
47
- errsol = appxtrue (sol,appxsol)
46
+ if :prob_choice ∈ keys (setups[i])
47
+ cur_appxsol = appxsol[setups[i][:prob_choice ]]
48
+ else
49
+ cur_appxsol = appxsol
50
+ end
51
+
52
+ if cur_appxsol != cur_appxsol
53
+ errsol = appxtrue (sol,cur_appxsol)
48
54
errors[i] = errsol. errors[error_estimate]
49
55
solutions[i] = errsol
50
56
else
51
57
errors[i] = sol. errors[error_estimate]
52
58
solutions[i] = sol
53
59
end
54
60
55
- benchmark_f = let prob= prob,alg= setups[i][:alg ],sol= sol,kwargs= kwargs
56
- function benchmark_f ()
57
- @elapsed solve (prob,alg,(sol. u),(sol. t),(sol. k);
58
- timeseries_errors = false ,
59
- dense_errors = false , kwargs... )
60
- end
61
+ if haskey (setups[i], :prob_choice )
62
+ _prob = prob[setups[i][:prob_choice ]]
63
+ else
64
+ _prob = prob
65
+ end
66
+
67
+ benchmark_f = let _prob= _prob,alg= setups[i][:alg ],sol= sol,kwargs= kwargs
68
+ () -> @elapsed solve (_prob, alg, sol. u, sol. t, sol. k;
69
+ timeseries_errors = false ,
70
+ dense_errors = false , kwargs... )
61
71
end
72
+ benchmark_f () # pre-compile
62
73
63
- b_t = benchmark_f ()
74
+ b_t = benchmark_f ()
64
75
if b_t > seconds
65
76
times[i] = b_t
66
77
else
67
- times[i] = minimum ([b_t; map (i -> benchmark_f (),2 : numruns)] )
78
+ times[i] = mapreduce (i -> benchmark_f (), min, 2 : numruns; init = b_t )
68
79
end
69
80
70
81
effs[i] = 1 / (errors[i]* times[i])
@@ -86,10 +97,10 @@ function ShootoutSet(probs,setups;probaux=nothing,
86
97
N = length (probs)
87
98
shootouts = Vector {Shootout} (undef,N)
88
99
winners = Vector {String} (undef,N)
89
- if names == nothing
100
+ if names === nothing
90
101
names = [string (nameof (typeof (setup[:alg ]))) for setup in setups]
91
102
end
92
- if probaux == nothing
103
+ if probaux === nothing
93
104
probaux = Vector {Dict{Symbol,Any}} (undef,N)
94
105
for i in 1 : N
95
106
probaux[i] = Dict {Symbol,Any} ()
@@ -143,7 +154,6 @@ mutable struct WorkPrecisionSet
143
154
prob
144
155
setups
145
156
names
146
- sample_error
147
157
error_estimate
148
158
numruns
149
159
end
@@ -153,68 +163,80 @@ function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
153
163
N = length (abstols)
154
164
errors = Vector {Float64} (undef,N)
155
165
times = Vector {Float64} (undef,N)
156
- if name == nothing
166
+ if name === nothing
157
167
name = " WP-Alg"
158
168
end
159
- timeseries_errors = error_estimate ∈ TIMESERIES_ERRORS
160
- dense_errors = error_estimate ∈ DENSE_ERRORS
161
- for i in 1 : N
162
- # Calculate errors and precompile
163
- if dts == nothing
164
- sol = solve (prob,alg;kwargs... ,abstol= abstols[i],
165
- reltol= reltols[i],timeseries_errors= timeseries_errors,
166
- dense_errors = dense_errors) # Compile and get result
167
- else
168
- sol = solve (prob,alg;kwargs... ,abstol= abstols[i],
169
- reltol= reltols[i],dt= dts[i],timeseries_errors= timeseries_errors,
170
- dense_errors = dense_errors) # Compile and get result
171
- end
172
169
173
- if appxsol != nothing
174
- errsol = appxtrue (sol,appxsol)
175
- errors[i] = mean (errsol. errors[error_estimate])
176
- else
177
- errors[i] = mean (sol. errors[error_estimate])
178
- end
170
+ if haskey (kwargs, :prob_choice )
171
+ _prob = prob[kwargs[:prob_choice ]]
172
+ else
173
+ _prob = prob
174
+ end
179
175
180
- benchmark_f = let dts= dts,prob= prob,alg= alg,sol= sol,abstols= abstols,reltols= reltols,kwargs= kwargs
181
- function benchmark_f ()
182
- if dts == nothing
183
- @elapsed solve (prob,alg,(sol. u),(sol. t),(sol. k);
184
- abstol= (abstols[i]),
185
- reltol= (reltols[i]),
186
- timeseries_errors = false ,
187
- dense_errors = false , kwargs... )
176
+ let _prob = _prob
177
+ timeseries_errors = error_estimate ∈ TIMESERIES_ERRORS
178
+ dense_errors = error_estimate ∈ DENSE_ERRORS
179
+ for i in 1 : N
180
+ if dts === nothing
181
+ sol = solve (_prob,alg;kwargs... ,abstol= abstols[i],
182
+ reltol= reltols[i],timeseries_errors= timeseries_errors,
183
+ dense_errors = dense_errors)
184
+ else
185
+ sol = solve (_prob,alg;kwargs... ,abstol= abstols[i],
186
+ reltol= reltols[i],dt= dts[i],timeseries_errors= timeseries_errors,
187
+ dense_errors = dense_errors)
188
+ end
189
+
190
+ if haskey (kwargs, :prob_choice )
191
+ cur_appxsol = appxsol[kwargs[:prob_choice ]]
192
+ else
193
+ cur_appxsol = appxsol
194
+ end
195
+
196
+ if cur_appxsol != = nothing
197
+ errsol = appxtrue (sol,cur_appxsol)
198
+ errors[i] = mean (errsol. errors[error_estimate])
199
+ else
200
+ errors[i] = mean (sol. errors[error_estimate])
201
+ end
202
+
203
+ benchmark_f = let dts= dts,_prob= _prob,alg= alg,sol= sol,abstols= abstols,reltols= reltols,kwargs= kwargs
204
+ if dts === nothing
205
+ () -> @elapsed solve (_prob, alg, sol. u, sol. t, sol. k;
206
+ abstol = abstols[i],
207
+ reltol = reltols[i],
208
+ timeseries_errors = false ,
209
+ dense_errors = false , kwargs... )
188
210
else
189
- @elapsed solve (prob, alg,( sol. u),( sol. t),( sol. k) ;
190
- abstol= ( abstols[i]) ,
191
- reltol= ( reltols[i]) ,
192
- dt = ( dts[i]) ,
193
- timeseries_errors = false ,
194
- dense_errors = false , kwargs... )
211
+ () -> @elapsed solve (_prob, alg, sol. u, sol. t, sol. k;
212
+ abstol = abstols[i],
213
+ reltol = reltols[i],
214
+ dt = dts[i],
215
+ timeseries_errors = false ,
216
+ dense_errors = false , kwargs... )
195
217
end
196
218
end
197
- end
219
+ benchmark_f () # pre-compile
198
220
199
- b_t = benchmark_f ()
200
- if b_t > seconds
201
- times[i] = b_t
202
- else
203
- times[i] = minimum ([b_t;map (i-> benchmark_f (),2 : numruns)])
221
+ b_t = benchmark_f ()
222
+ if b_t > seconds
223
+ times[i] = b_t
224
+ else
225
+ times[i] = mapreduce (i -> benchmark_f (), min, 2 : numruns; init = b_t)
226
+ end
204
227
end
205
228
end
206
229
return WorkPrecision (prob,abstols,reltols,errors,times,name,N)
207
230
end
208
231
209
- function WorkPrecisionSet (prob:: Union {AbstractODEProblem,AbstractDDEProblem,
210
- AbstractDAEProblem},
232
+ function WorkPrecisionSet (prob,
211
233
abstols,reltols,setups;
212
234
print_names= false ,names= nothing ,appxsol= nothing ,
213
235
error_estimate= :final ,
214
236
test_dt= nothing ,kwargs... )
215
237
N = length (setups)
216
238
wps = Vector {WorkPrecision} (undef,N)
217
- if names == nothing
239
+ if names === nothing
218
240
names = [string (nameof (typeof (setup[:alg ]))) for setup in setups]
219
241
end
220
242
for i in 1 : N
@@ -231,7 +253,7 @@ function WorkPrecisionSet(prob::Union{AbstractODEProblem,AbstractDDEProblem,
231
253
name= names[i],kwargs... ,setups[i]. .. )
232
254
end
233
255
end
234
- return WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,nothing , error_estimate,nothing )
256
+ return WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,error_estimate,nothing )
235
257
end
236
258
237
259
@def error_calculation begin
285
307
function WorkPrecisionSet (prob:: AbstractRODEProblem ,abstols,reltols,setups,test_dt= nothing ;
286
308
numruns= 20 ,numruns_error = 20 ,
287
309
print_names= false ,names= nothing ,appxsol_setup= nothing ,
288
- error_estimate= :final ,parallel_type = :none ,kwargs... )
310
+ error_estimate= :final ,parallel_type = :none ,
311
+ kwargs... )
289
312
290
313
timeseries_errors = DiffEqBase. has_analytic (prob. f) && error_estimate ∈ TIMESERIES_ERRORS
291
314
weak_timeseries_errors = error_estimate ∈ WEAK_TIMESERIES_ERRORS
@@ -294,7 +317,7 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
294
317
N = length (setups); M = length (abstols)
295
318
times = Array {Float64} (undef,M,N)
296
319
tmp_solutions = Array {Any} (undef,numruns_error,M,N)
297
- if names == nothing
320
+ if names === nothing
298
321
names = [string (nameof (typeof (setup[:alg ]))) for setup in setups]
299
322
end
300
323
time_tmp = Vector {Float64} (undef,numruns)
@@ -310,8 +333,7 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
310
333
@error_calculation
311
334
end
312
335
end
313
- analytical_solution_ends = [tmp_solutions[i,1 ,1 ]. u_analytic[end ] for i in 1 : numruns_error]
314
- sample_error = 1.96 std (norm .(analytical_solution_ends))/ sqrt (numruns_error)
336
+
315
337
_solutions_k = [[EnsembleSolution (tmp_solutions[:,j,k],0.0 ,true ) for j in 1 : M] for k in 1 : N]
316
338
solutions = [[DiffEqBase. calculate_ensemble_errors (sim;weak_timeseries_errors= weak_timeseries_errors,weak_dense_errors= weak_dense_errors) for sim in sol_k] for sol_k in _solutions_k]
317
339
if error_estimate ∈ WEAK_ERRORS
@@ -365,46 +387,62 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
365
387
end
366
388
367
389
wps = [WorkPrecision (prob,abstols,reltols,errors[i],times[:,i],names[i],N) for i in 1 : N]
368
- WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,sample_error, error_estimate,numruns_error)
390
+ WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,error_estimate,numruns_error)
369
391
end
370
392
371
- @def sample_errors begin
372
- if ! DiffEqBase. has_analytic (prob. f)
373
- true_sol = solve (prob,appxsol_setup[:alg ];kwargs... ,appxsol_setup... ,
374
- save_everystep= false )
375
- analytical_solution_ends[i] = norm (true_sol. u[end ])
376
- else
377
- _dt = prob. tspan[2 ] - prob. tspan[1 ]
378
- if typeof (prob. u0) <: Number
379
- W = sqrt (_dt)* randn ()
380
- else
381
- W = sqrt (_dt)* randn (size (prob. u0))
393
+ function get_sample_errors (prob:: AbstractRODEProblem ,setup,test_dt= nothing ;
394
+ appxsol_setup= nothing ,
395
+ numruns,error_estimate= :final ,
396
+ sample_error_runs = Int (1e7 ),
397
+ solution_runs,
398
+ parallel_type = :none ,kwargs... )
399
+
400
+ maxnumruns = findmax (numruns)[1 ]
401
+
402
+ tmp_solutions_full = map (1 : solution_runs) do i
403
+ @info " Solution Run: $i "
404
+ # Use the WorkPrecision stuff to calculate the errors
405
+ tmp_solutions = Array {Any} (undef,maxnumruns,1 ,1 )
406
+ setups = [setup]
407
+ abstols = [1e-2 ] # Standard default
408
+ reltols = [1e-2 ] # Standard default
409
+ M = 1 ; N = 1
410
+ timeseries_errors = false ; dense_errors = false
411
+ if parallel_type == :threads
412
+ Threads. @threads for i in 1 : maxnumruns
413
+ @error_calculation
414
+ end
415
+ elseif parallel_type == :none
416
+ for i in 1 : maxnumruns
417
+ @error_calculation
418
+ end
382
419
end
383
- analytical_solution_ends[i] = norm (prob . f . analytic (prob . u0,prob . p,prob . tspan[ 2 ],W) )
420
+ tmp_solutions = vec (tmp_solutions )
384
421
end
385
- end
386
422
387
- function get_sample_errors (prob:: AbstractRODEProblem ,test_dt= nothing ;
388
- appxsol_setup= nothing ,
389
- numruns= 20 ,std_estimation_runs = maximum (numruns),
390
- error_estimate= :final ,parallel_type = :none ,kwargs... )
391
- _std_estimation_runs = Int (std_estimation_runs)
392
- analytical_solution_ends = Vector {typeof(norm(prob.u0))} (undef,_std_estimation_runs)
393
- if parallel_type == :threads
394
- Threads. @threads for i in 1 : _std_estimation_runs
395
- @sample_errors
396
- end
397
- elseif parallel_type == :none
398
- for i in 1 : _std_estimation_runs
399
- @info " Standard deviation estimation: $i /$_std_estimation_runs "
400
- @sample_errors
423
+ if DiffEqBase. has_analytic (prob. f)
424
+ analytical_mean_end = mean (1 : sample_error_runs) do i
425
+ _dt = prob. tspan[2 ] - prob. tspan[1 ]
426
+ if typeof (prob. u0) <: Number
427
+ W = sqrt (_dt)* randn ()
428
+ else
429
+ W = sqrt (_dt)* randn (size (prob. u0))
430
+ end
431
+ prob. f. analytic (prob. u0,prob. p,prob. tspan[2 ],W)
401
432
end
433
+ else
434
+ # Use the mean of the means as the analytical mean
435
+ analytical_mean_end = mean (mean (tmp_solutions[i]. u[end ] for i in 1 : length (tmp_solutions)) for tmp_solutions in tmp_solutions_full)
402
436
end
403
- est_std = std (analytical_solution_ends)
404
- if typeof (numruns) <: Number
405
- return 1.96 est_std/ sqrt (numruns)
437
+
438
+ if numruns isa Number
439
+ mean_solution_ends = [mean ([tmp_solutions[i]. u[end ] for i in 1 : maxnumruns]) for tmp_solutions in tmp_solutions_full]
440
+ return sample_error = 1.96 std (norm (mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/ sqrt (numruns)
406
441
else
407
- return [1.96 est_std/ sqrt (_numruns) for _numruns in numruns]
442
+ map (1 : length (numruns)) do i
443
+ mean_solution_ends = [mean ([tmp_solutions[i]. u[end ] for i in 1 : numruns[i]]) for tmp_solutions in tmp_solutions_full]
444
+ sample_error = 1.96 std (norm (mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/ sqrt (numruns[i])
445
+ end
408
446
end
409
447
end
410
448
0 commit comments