@@ -155,7 +155,6 @@ mutable struct WorkPrecisionSet
155
155
prob
156
156
setups
157
157
names
158
- sample_error
159
158
error_estimate
160
159
numruns
161
160
end
@@ -257,7 +256,7 @@ function WorkPrecisionSet(prob,
257
256
name= names[i],kwargs... ,setups[i]. .. )
258
257
end
259
258
end
260
- return WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,nothing , error_estimate,nothing )
259
+ return WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,error_estimate,nothing )
261
260
end
262
261
263
262
@def error_calculation begin
311
310
function WorkPrecisionSet (prob:: AbstractRODEProblem ,abstols,reltols,setups,test_dt= nothing ;
312
311
numruns= 20 ,numruns_error = 20 ,
313
312
print_names= false ,names= nothing ,appxsol_setup= nothing ,
314
- error_estimate= :final ,parallel_type = :none ,kwargs... )
313
+ error_estimate= :final ,parallel_type = :none ,
314
+ kwargs... )
315
315
316
316
timeseries_errors = DiffEqBase. has_analytic (prob. f) && error_estimate ∈ TIMESERIES_ERRORS
317
317
weak_timeseries_errors = error_estimate ∈ WEAK_TIMESERIES_ERRORS
@@ -336,8 +336,7 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
336
336
@error_calculation
337
337
end
338
338
end
339
- analytical_solution_ends = [tmp_solutions[i,1 ,1 ]. u_analytic[end ] for i in 1 : numruns_error]
340
- sample_error = 1.96 std (norm .(analytical_solution_ends))/ sqrt (numruns_error)
339
+
341
340
_solutions_k = [[EnsembleSolution (tmp_solutions[:,j,k],0.0 ,true ) for j in 1 : M] for k in 1 : N]
342
341
solutions = [[DiffEqBase. calculate_ensemble_errors (sim;weak_timeseries_errors= weak_timeseries_errors,weak_dense_errors= weak_dense_errors) for sim in sol_k] for sol_k in _solutions_k]
343
342
if error_estimate ∈ WEAK_ERRORS
@@ -391,46 +390,62 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
391
390
end
392
391
393
392
wps = [WorkPrecision (prob,abstols,reltols,errors[i],times[:,i],names[i],N) for i in 1 : N]
394
- WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,sample_error, error_estimate,numruns_error)
393
+ WorkPrecisionSet (wps,N,abstols,reltols,prob,setups,names,error_estimate,numruns_error)
395
394
end
396
395
397
- @def sample_errors begin
398
- if ! DiffEqBase. has_analytic (prob. f)
399
- true_sol = solve (prob,appxsol_setup[:alg ];kwargs... ,appxsol_setup... ,
400
- save_everystep= false )
401
- analytical_solution_ends[i] = norm (true_sol. u[end ])
402
- else
403
- _dt = prob. tspan[2 ] - prob. tspan[1 ]
404
- if typeof (prob. u0) <: Number
405
- W = sqrt (_dt)* randn ()
406
- else
407
- W = sqrt (_dt)* randn (size (prob. u0))
396
+ function get_sample_errors (prob:: AbstractRODEProblem ,setup,test_dt= nothing ;
397
+ appxsol_setup= nothing ,
398
+ numruns,error_estimate= :final ,
399
+ sample_error_runs = Int (1e7 ),
400
+ solution_runs,
401
+ parallel_type = :none ,kwargs... )
402
+
403
+ maxnumruns = findmax (numruns)[1 ]
404
+
405
+ tmp_solutions_full = map (1 : solution_runs) do i
406
+ @info " Solution Run: $i "
407
+ # Use the WorkPrecision stuff to calculate the errors
408
+ tmp_solutions = Array {Any} (undef,maxnumruns,1 ,1 )
409
+ setups = [setup]
410
+ abstols = [1e-2 ] # Standard default
411
+ reltols = [1e-2 ] # Standard default
412
+ M = 1 ; N = 1
413
+ timeseries_errors = false ; dense_errors = false
414
+ if parallel_type == :threads
415
+ Threads. @threads for i in 1 : maxnumruns
416
+ @error_calculation
417
+ end
418
+ elseif parallel_type == :none
419
+ for i in 1 : maxnumruns
420
+ @error_calculation
421
+ end
408
422
end
409
- analytical_solution_ends[i] = norm (prob . f . analytic (prob . u0,prob . p,prob . tspan[ 2 ],W) )
423
+ tmp_solutions = vec (tmp_solutions )
410
424
end
411
- end
412
425
413
- function get_sample_errors (prob:: AbstractRODEProblem ,test_dt= nothing ;
414
- appxsol_setup= nothing ,
415
- numruns= 20 ,std_estimation_runs = maximum (numruns),
416
- error_estimate= :final ,parallel_type = :none ,kwargs... )
417
- _std_estimation_runs = Int (std_estimation_runs)
418
- analytical_solution_ends = Vector {typeof(norm(prob.u0))} (undef,_std_estimation_runs)
419
- if parallel_type == :threads
420
- Threads. @threads for i in 1 : _std_estimation_runs
421
- @sample_errors
422
- end
423
- elseif parallel_type == :none
424
- for i in 1 : _std_estimation_runs
425
- @info " Standard deviation estimation: $i /$_std_estimation_runs "
426
- @sample_errors
426
+ if DiffEqBase. has_analytic (prob. f)
427
+ analytical_mean_end = mean (1 : sample_error_runs) do i
428
+ _dt = prob. tspan[2 ] - prob. tspan[1 ]
429
+ if typeof (prob. u0) <: Number
430
+ W = sqrt (_dt)* randn ()
431
+ else
432
+ W = sqrt (_dt)* randn (size (prob. u0))
433
+ end
434
+ prob. f. analytic (prob. u0,prob. p,prob. tspan[2 ],W)
427
435
end
436
+ else
437
+ # Use the mean of the means as the analytical mean
438
+ analytical_mean_end = mean (mean (tmp_solutions[i]. u[end ] for i in 1 : length (tmp_solutions)) for tmp_solutions in tmp_solutions_full)
428
439
end
429
- est_std = std (analytical_solution_ends)
430
- if typeof (numruns) <: Number
431
- return 1.96 est_std/ sqrt (numruns)
440
+
441
+ if numruns isa Number
442
+ mean_solution_ends = [mean ([tmp_solutions[i]. u[end ] for i in 1 : maxnumruns]) for tmp_solutions in tmp_solutions_full]
443
+ return sample_error = 1.96 std (norm (mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/ sqrt (numruns)
432
444
else
433
- return [1.96 est_std/ sqrt (_numruns) for _numruns in numruns]
445
+ map (1 : length (numruns)) do i
446
+ mean_solution_ends = [mean ([tmp_solutions[i]. u[end ] for i in 1 : numruns[i]]) for tmp_solutions in tmp_solutions_full]
447
+ sample_error = 1.96 std (norm (mean_sol_end - analytical_mean_end) for mean_sol_end in mean_solution_ends)/ sqrt (numruns[i])
448
+ end
434
449
end
435
450
end
436
451
0 commit comments