Skip to content

Commit b69b4ca

Browse files
Merge pull request #2184 from oscardssmith/os/default_solver-v2
Redesign default ODE solver to be type-grounded and lazy
2 parents 6d35d93 + d7193fb commit b69b4ca

14 files changed

+548
-180
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
13+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1314
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
1415
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1516
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
@@ -75,7 +76,6 @@ RecursiveArrayTools = "2.36, 3"
7576
Reexport = "1.0"
7677
SciMLBase = "2.27.1"
7778
SciMLOperators = "0.3"
78-
SciMLStructures = "1"
7979
SimpleNonlinearSolve = "1"
8080
SimpleUnPack = "1"
8181
SparseArrays = "1.9"

src/OrdinaryDiffEq.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ using LinearSolve, SimpleNonlinearSolve
2626

2727
using LineSearches
2828

29+
import EnumX
30+
2931
import FillArrays: Trues
3032

3133
# Interfaces
@@ -141,6 +143,7 @@ include("nlsolve/functional.jl")
141143
include("nlsolve/newton.jl")
142144

143145
include("generic_rosenbrock.jl")
146+
include("composite_algs.jl")
144147

145148
include("caches/basic_caches.jl")
146149
include("caches/low_order_rk_caches.jl")
@@ -234,7 +237,6 @@ include("constants.jl")
234237
include("solve.jl")
235238
include("initdt.jl")
236239
include("interp_func.jl")
237-
include("composite_algs.jl")
238240

239241
import PrecompileTools
240242

@@ -253,9 +255,14 @@ PrecompileTools.@compile_workload begin
253255
Tsit5(), Vern7()
254256
]
255257

256-
stiff = [Rosenbrock23(), Rosenbrock23(autodiff = false),
257-
Rodas5P(), Rodas5P(autodiff = false),
258-
FBDF(), FBDF(autodiff = false)
258+
stiff = [Rosenbrock23(),
259+
Rodas5P(),
260+
FBDF()
261+
]
262+
263+
default_ode = [
264+
DefaultODEAlgorithm(autodiff=false),
265+
DefaultODEAlgorithm()
259266
]
260267

261268
autoswitch = [
@@ -284,7 +291,11 @@ PrecompileTools.@compile_workload begin
284291
append!(solver_list, stiff)
285292
end
286293

287-
if Preferences.@load_preference("PrecompileAutoSwitch", true)
294+
if Preferences.@load_preference("PrecompileDefault", true)
295+
append!(solver_list, default_ode)
296+
end
297+
298+
if Preferences.@load_preference("PrecompileAutoSwitch", false)
288299
append!(solver_list, autoswitch)
289300
end
290301

src/alg_utils.jl

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ isimplicit(alg::CompositeAlgorithm) = any(isimplicit.(alg.algs))
172172

173173
isdtchangeable(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
174174
isdtchangeable(alg::CompositeAlgorithm) = all(isdtchangeable.(alg.algs))
175+
175176
function isdtchangeable(alg::Union{LawsonEuler, NorsettEuler, LieEuler, MagnusGauss4,
176177
CayleyEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4, ETD2})
177178
false
@@ -205,31 +206,35 @@ qmax_default(alg::CompositeAlgorithm) = minimum(qmax_default.(alg.algs))
205206
qmax_default(alg::DP8) = 6
206207
qmax_default(alg::Union{RadauIIA3, RadauIIA5}) = 8
207208

209+
function has_chunksize(alg::OrdinaryDiffEqAlgorithm)
210+
return alg isa Union{OrdinaryDiffEqExponentialAlgorithm,
211+
OrdinaryDiffEqAdaptiveExponentialAlgorithm,
212+
OrdinaryDiffEqImplicitAlgorithm,
213+
OrdinaryDiffEqAdaptiveImplicitAlgorithm,
214+
DAEAlgorithm,
215+
CompositeAlgorithm}
216+
end
208217
function get_chunksize(alg::OrdinaryDiffEqAlgorithm)
209218
error("This algorithm does not have a chunk size defined.")
210219
end
211-
get_chunksize(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
212-
get_chunksize(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
213-
get_chunksize(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = Val(CS)
214-
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
215-
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
216-
CS,
217-
AD
218-
}
220+
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS},
221+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS},
222+
OrdinaryDiffEqImplicitAlgorithm{CS},
223+
OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS},
224+
DAEAlgorithm{CS},
225+
CompositeAlgorithm{CS}}) where {CS}
219226
Val(CS)
220227
end
221228

222229
function get_chunksize_int(alg::OrdinaryDiffEqAlgorithm)
223230
error("This algorithm does not have a chunk size defined.")
224231
end
225-
get_chunksize_int(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = CS
226-
get_chunksize_int(alg::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = CS
227-
get_chunksize_int(alg::DAEAlgorithm{CS, AD}) where {CS, AD} = CS
228-
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
229-
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}}) where {
230-
CS,
231-
AD
232-
}
232+
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS},
233+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS},
234+
OrdinaryDiffEqImplicitAlgorithm{CS},
235+
OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS},
236+
DAEAlgorithm{CS},
237+
CompositeAlgorithm{CS}}) where {CS}
233238
CS
234239
end
235240
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])
@@ -965,10 +970,12 @@ alg_can_repeat_jac(alg::OrdinaryDiffEqNewtonAdaptiveAlgorithm) = true
965970
alg_can_repeat_jac(alg::IRKC) = false
966971

967972
function unwrap_alg(alg::SciMLBase.DEAlgorithm, is_stiff)
968-
iscomp = alg isa CompositeAlgorithm
969-
if !iscomp
973+
if !(alg isa CompositeAlgorithm)
970974
return alg
971975
elseif alg.choice_function isa AutoSwitchCache
976+
if length(alg.algs) > 2
977+
return alg.algs[alg.choice_function.current]
978+
end
972979
if is_stiff === nothing
973980
throwautoswitch(alg)
974981
end
@@ -985,18 +992,21 @@ end
985992

986993
function unwrap_alg(integrator, is_stiff)
987994
alg = integrator.alg
988-
iscomp = alg isa CompositeAlgorithm
989-
if !iscomp
995+
if !(alg isa CompositeAlgorithm)
990996
return alg
991997
elseif alg.choice_function isa AutoSwitchCache
992-
if is_stiff === nothing
993-
throwautoswitch(alg)
994-
end
995-
num = is_stiff ? 2 : 1
996-
if num == 1
997-
return alg.algs[1]
998+
if length(alg.algs) > 2
999+
alg.algs[alg.choice_function.current]
9981000
else
999-
return alg.algs[2]
1001+
if is_stiff === nothing
1002+
throwautoswitch(alg)
1003+
end
1004+
num = is_stiff ? 2 : 1
1005+
if num == 1
1006+
return alg.algs[1]
1007+
else
1008+
return alg.algs[2]
1009+
end
10001010
end
10011011
else
10021012
return _eval_index(identity, alg.algs, integrator.cache.current)
@@ -1071,3 +1081,5 @@ is_mass_matrix_alg(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = false
10711081
is_mass_matrix_alg(alg::CompositeAlgorithm) = all(is_mass_matrix_alg, alg.algs)
10721082
is_mass_matrix_alg(alg::RosenbrockAlgorithm) = true
10731083
is_mass_matrix_alg(alg::NewtonAlgorithm) = !isesdirk(alg)
1084+
# hack for the default alg
1085+
is_mass_matrix_alg(alg::CompositeAlgorithm{<:Any, <:Tuple{Tsit5, Vern7, Rosenbrock23, Rodas5P, FBDF, FBDF}}) = true

src/algorithms.jl

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,7 +2969,7 @@ Scientific Computing, 18 (1), pp. 1-22.
29692969
differential-algebraic problems. Computational mathematics (2nd revised ed.), Springer (1996)
29702970
29712971
#### ROS2PR, ROS2S, ROS3PR, Scholz4_7
2972-
-Rang, Joachim (2014): The Prothero and Robinson example:
2972+
-Rang, Joachim (2014): The Prothero and Robinson example:
29732973
Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods.
29742974
https://doi.org/10.24355/dbbs.084-201408121139-0
29752975
@@ -3014,16 +3014,16 @@ University of Geneva, Switzerland.
30143014
https://doi.org/10.1016/j.cam.2015.03.010
30153015
30163016
#### ROS3PRL, ROS3PRL2
3017-
-Rang, Joachim (2014): The Prothero and Robinson example:
3017+
-Rang, Joachim (2014): The Prothero and Robinson example:
30183018
Convergence studies for Runge-Kutta and Rosenbrock-Wanner methods.
30193019
https://doi.org/10.24355/dbbs.084-201408121139-0
30203020
30213021
#### Rodas5P
3022-
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
3022+
- Steinebach G. Construction of Rosenbrock–Wanner method Rodas5P and numerical benchmarks within the Julia Differential Equations package.
30233023
In: BIT Numerical Mathematics, 63(2), 2023
30243024
30253025
#### Rodas23W, Rodas3P, Rodas5Pe, Rodas5Pr
3026-
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
3026+
- Steinebach G. Rosenbrock methods within OrdinaryDiffEq.jl - Overview, recent developments and applications -
30273027
Preprint 2024
30283028
https://github.com/hbrs-cse/RosenbrockMethods/blob/main/paper/JuliaPaper.pdf
30293029
@@ -3239,9 +3239,13 @@ end
32393239

32403240
#########################################
32413241

3242-
struct CompositeAlgorithm{T, F} <: OrdinaryDiffEqCompositeAlgorithm
3242+
struct CompositeAlgorithm{CS, T, F} <: OrdinaryDiffEqCompositeAlgorithm
32433243
algs::T
32443244
choice_function::F
3245+
function CompositeAlgorithm(algs::T, choice_function::F) where {T,F}
3246+
CS = mapreduce(alg->has_chunksize(alg) ? get_chunksize_int(alg) : 0, max, algs)
3247+
new{CS, T, F}(algs, choice_function)
3248+
end
32453249
end
32463250

32473251
TruncatedStacktraces.@truncate_stacktrace CompositeAlgorithm 1
@@ -3250,6 +3254,62 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
32503254
Base.Experimental.silence!(CompositeAlgorithm)
32513255
end
32523256

3257+
mutable struct AutoSwitchCache{nAlg, sAlg, tolType, T}
3258+
count::Int
3259+
successive_switches::Int
3260+
nonstiffalg::nAlg
3261+
stiffalg::sAlg
3262+
is_stiffalg::Bool
3263+
maxstiffstep::Int
3264+
maxnonstiffstep::Int
3265+
nonstifftol::tolType
3266+
stifftol::tolType
3267+
dtfac::T
3268+
stiffalgfirst::Bool
3269+
switch_max::Int
3270+
current::Int
3271+
function AutoSwitchCache(count::Int,
3272+
successive_switches::Int,
3273+
nonstiffalg::nAlg,
3274+
stiffalg::sAlg,
3275+
is_stiffalg::Bool,
3276+
maxstiffstep::Int,
3277+
maxnonstiffstep::Int,
3278+
nonstifftol::tolType,
3279+
stifftol::tolType,
3280+
dtfac::T,
3281+
stiffalgfirst::Bool,
3282+
switch_max::Int,
3283+
current::Int=0) where {nAlg, sAlg, tolType, T}
3284+
new{nAlg, sAlg, tolType, T}(count,
3285+
successive_switches,
3286+
nonstiffalg,
3287+
stiffalg,
3288+
is_stiffalg,
3289+
maxstiffstep,
3290+
maxnonstiffstep,
3291+
nonstifftol,
3292+
stifftol,
3293+
dtfac,
3294+
stiffalgfirst,
3295+
switch_max,
3296+
current)
3297+
end
3298+
3299+
end
3300+
3301+
struct AutoSwitch{nAlg, sAlg, tolType, T}
3302+
nonstiffalg::nAlg
3303+
stiffalg::sAlg
3304+
maxstiffstep::Int
3305+
maxnonstiffstep::Int
3306+
nonstifftol::tolType
3307+
stifftol::tolType
3308+
dtfac::T
3309+
stiffalgfirst::Bool
3310+
switch_max::Int
3311+
end
3312+
32533313
################################################################################
32543314
"""
32553315
MEBDF2: Multistep Method

src/cache_utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@ function DiffEqBase.unwrap_cache(integrator::ODEIntegrator, is_stiff)
88
iscomp = alg isa CompositeAlgorithm
99
if !iscomp
1010
return cache
11+
elseif cache isa DefaultCache
12+
current = integrator.cache.current
13+
if current == 1
14+
return cache.cache1
15+
elseif current == 2
16+
return cache.cache2
17+
elseif current == 3
18+
return cache.cache3
19+
elseif current == 4
20+
return cache.cache4
21+
elseif current == 5
22+
return cache.cache5
23+
elseif current == 6
24+
return cache.cache6
25+
end
1126
elseif alg.choice_function isa AutoSwitch
1227
num = is_stiff ? 2 : 1
1328
return cache.caches[num]

src/caches/basic_caches.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,26 @@ end
1212

1313
TruncatedStacktraces.@truncate_stacktrace CompositeCache 1
1414

15-
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
16-
Base.Experimental.silence!(CompositeCache)
15+
mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache
16+
args::A
17+
choice_function::F
18+
current::Int
19+
cache1::T1
20+
cache2::T2
21+
cache3::T3
22+
cache4::T4
23+
cache5::T5
24+
cache6::T6
25+
function DefaultCache{T1, T2, T3, T4, T5, T6, F}(args, choice_function, current) where {T1, T2, T3, T4, T5, T6, F}
26+
new{T1, T2, T3, T4, T5, T6, typeof(args), F}(args, choice_function, current)
27+
end
1728
end
1829

19-
function alg_cache(alg::CompositeAlgorithm{Tuple{T1, T2}, F}, u, rate_prototype,
20-
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits},
21-
::Type{tTypeNoUnits}, uprev,
22-
uprev2, f, t, dt, reltol, p, calck,
23-
::Val{V}) where {T1, T2, F, V, uEltypeNoUnits, uBottomEltypeNoUnits,
24-
tTypeNoUnits}
25-
caches = (
26-
alg_cache(alg.algs[1], u, rate_prototype, uEltypeNoUnits,
27-
uBottomEltypeNoUnits,
28-
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)),
29-
alg_cache(alg.algs[2], u, rate_prototype, uEltypeNoUnits,
30-
uBottomEltypeNoUnits,
31-
tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val(V)))
32-
CompositeCache(caches, alg.choice_function, 1)
30+
TruncatedStacktraces.@truncate_stacktrace DefaultCache 1
31+
32+
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
33+
Base.Experimental.silence!(CompositeCache)
34+
Base.Experimental.silence!(DefaultCache)
3335
end
3436

3537
function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -41,6 +43,24 @@ function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoU
4143
CompositeCache(caches, alg.choice_function, 1)
4244
end
4345

46+
function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u,
47+
rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
48+
uprev, uprev2, f, t, dt, reltol, p, calck,
49+
::Val{V}) where {CS, V, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, A1, A2, A3, A4, A5, A6}
50+
51+
args = (u, rate_prototype, uEltypeNoUnits,
52+
uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt,
53+
reltol, p, calck, Val(V))
54+
argT = map(typeof, args)
55+
T1 = Base.promote_op(alg_cache, A1, argT...)
56+
T2 = Base.promote_op(alg_cache, A2, argT...)
57+
T3 = Base.promote_op(alg_cache, A3, argT...)
58+
T4 = Base.promote_op(alg_cache, A4, argT...)
59+
T5 = Base.promote_op(alg_cache, A5, argT...)
60+
T6 = Base.promote_op(alg_cache, A6, argT...)
61+
DefaultCache{T1, T2, T3, T4, T5, T6, typeof(alg.choice_function)}(args, alg.choice_function, 1)
62+
end
63+
4464
# map + closure approach doesn't infer
4565
@generated function __alg_cache(algs::T, u, rate_prototype, ::Type{uEltypeNoUnits},
4666
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev,

0 commit comments

Comments
 (0)