Skip to content

Commit 25bca1d

Browse files
Update to use DiffEqBase initialization algorithms
- Import DefaultInit, ShampineCollocationInit, BrownBasicInit from DiffEqBase - Add extended versions for OrdinaryDiffEq-specific options - Update OrdinaryDiffEqNonlinearSolve to accept both base and extended types - Add helper functions to handle algorithm parameter differences - Bump DiffEqBase to 6.190 for new initialization algorithms
1 parent 9088786 commit 25bca1d

File tree

3 files changed

+108
-34
lines changed

3 files changed

+108
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ Adapt = "4.3"
112112
ArrayInterface = "7.19"
113113
CommonSolve = "0.2.4"
114114
DataStructures = "0.18.22, 0.19"
115-
DiffEqBase = "6.186"
115+
DiffEqBase = "6.190"
116116
DocStringExtensions = "0.9.5"
117117
EnumX = "1.0.5"
118118
ExplicitImports = "1.13.1"

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
1-
struct DefaultInit <: SciMLBase.DAEInitializationAlgorithm end
1+
# Import standard initialization algorithms from DiffEqBase
2+
import DiffEqBase: DefaultInit, ShampineCollocationInit, BrownBasicInit
23

3-
struct ShampineCollocationInit{T, F} <: SciMLBase.DAEInitializationAlgorithm
4+
# Re-export for backward compatibility
5+
export DefaultInit, ShampineCollocationInit, BrownBasicInit
6+
7+
# Extended versions with OrdinaryDiffEq-specific options
8+
struct ShampineCollocationInitExt{T, F} <: SciMLBase.DAEInitializationAlgorithm
49
initdt::T
510
nlsolve::F
611
end
7-
function ShampineCollocationInit(; initdt = nothing, nlsolve = nothing)
8-
ShampineCollocationInit(initdt, nlsolve)
12+
function ShampineCollocationInitExt(; initdt = nothing, nlsolve = nothing)
13+
ShampineCollocationInitExt(initdt, nlsolve)
914
end
10-
function ShampineCollocationInit(initdt)
11-
ShampineCollocationInit(; initdt = initdt, nlsolve = nothing)
15+
function ShampineCollocationInitExt(initdt)
16+
ShampineCollocationInitExt(; initdt = initdt, nlsolve = nothing)
1217
end
1318

19+
# Constructor that delegates to extended version for backward compatibility
20+
function ShampineCollocationInit(; initdt = nothing, nlsolve = nothing)
21+
if initdt !== nothing || nlsolve !== nothing
22+
ShampineCollocationInitExt(initdt, nlsolve)
23+
else
24+
DiffEqBase.ShampineCollocationInit()
25+
end
26+
end
27+
ShampineCollocationInit(initdt) = ShampineCollocationInitExt(; initdt = initdt, nlsolve = nothing)
28+
1429
struct BrownFullBasicInit{T, F} <: SciMLBase.DAEInitializationAlgorithm
1530
abstol::T
1631
nlsolve::F
@@ -20,6 +35,18 @@ function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
2035
end
2136
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)
2237

38+
# Alias for consistency with DiffEqBase naming
39+
const BrownBasicInitExt = BrownFullBasicInit
40+
41+
# Constructor that delegates for backward compatibility
42+
function BrownBasicInit(; abstol = nothing, nlsolve = nothing)
43+
if abstol !== nothing || nlsolve !== nothing
44+
BrownFullBasicInit(something(abstol, 1e-10), nlsolve)
45+
else
46+
DiffEqBase.BrownBasicInit()
47+
end
48+
end
49+
2350
## Notes
2451

2552
#=
@@ -177,3 +204,20 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
177204
SciMLBase.get_initial_values(
178205
prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol)
179206
end
207+
208+
# Delegate base DiffEqBase types to extended versions with default options
209+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
210+
alg::DiffEqBase.ShampineCollocationInit, isinplace::Union{Val{true}, Val{false}})
211+
_initialize_dae!(integrator, prob, ShampineCollocationInitExt(nothing, nothing), isinplace)
212+
end
213+
214+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
215+
alg::DiffEqBase.BrownBasicInit, isinplace::Union{Val{true}, Val{false}})
216+
_initialize_dae!(integrator, prob, BrownFullBasicInit(integrator.opts.abstol, nothing), isinplace)
217+
end
218+
219+
# Handle DiffEqBase.DefaultInit same as our DefaultInit
220+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
221+
alg::DiffEqBase.DefaultInit, isinplace::Union{Val{true}, Val{false}})
222+
_initialize_dae!(integrator, prob, DefaultInit(), isinplace)
223+
end

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,33 @@ end
1616
end
1717
end
1818

19+
# Helper to get abstol from either the algorithm or integrator opts
20+
@inline function get_abstol(alg::Union{BrownFullBasicInit, OrdinaryDiffEqCore.BrownFullBasicInit}, integrator)
21+
return alg.abstol
22+
end
23+
24+
@inline function get_abstol(alg::DiffEqBase.BrownBasicInit, integrator)
25+
return integrator.opts.abstol
26+
end
27+
28+
# Helper to get nlsolve from either the algorithm or nothing
29+
@inline function get_nlsolve(alg::Union{BrownFullBasicInit, OrdinaryDiffEqCore.BrownFullBasicInit, OrdinaryDiffEqCore.ShampineCollocationInitExt})
30+
return alg.nlsolve
31+
end
32+
33+
@inline function get_nlsolve(alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.ShampineCollocationInit})
34+
return nothing
35+
end
36+
37+
# Helper to get initdt from either the algorithm or nothing
38+
@inline function get_initdt(alg::OrdinaryDiffEqCore.ShampineCollocationInitExt)
39+
return alg.initdt
40+
end
41+
42+
@inline function get_initdt(alg::Union{DiffEqBase.ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInit})
43+
return nothing
44+
end
45+
1946
function default_nlsolve(
2047
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
2148
FastShortcutNonlinearPolyalg(;
@@ -53,19 +80,20 @@ Solve for `u`
5380
5481
=#
5582

56-
function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocationInit,
83+
function _initialize_dae!(integrator, prob::ODEProblem, alg::Union{ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInitExt},
5784
isinplace::Val{true})
5885
@unpack p, t, f = integrator
5986
M = integrator.f.mass_matrix
6087
dtmax = integrator.opts.dtmax
6188
tmp = first(get_tmp_cache(integrator))
6289
u0 = integrator.u
6390

64-
dt = if alg.initdt === nothing
91+
initdt = get_initdt(alg)
92+
dt = if initdt === nothing
6593
integrator.dt != 0 ? min(integrator.dt / 5, dtmax) :
6694
(prob.tspan[end] - prob.tspan[begin]) / 1000 # Haven't implemented norm reduction
6795
else
68-
alg.initdt
96+
initdt
6997
end
7098

7199
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
@@ -77,7 +105,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
77105

78106
check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t, isinplace) && return
79107

80-
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
108+
if isdefined(integrator.cache, :nlsolver) && !isnothing(get_nlsolve(alg))
81109
# backward Euler
82110
nlsolver = integrator.cache.nlsolver
83111
oldγ, oldc, oldmethod,
@@ -149,7 +177,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
149177
jac_prototype = f.jac_prototype,
150178
jac = jac)
151179
nlprob = NonlinearProblem(nlfunc, integrator.u, p)
152-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
180+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u0, nlprob, isAD)
153181
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
154182
reltol = integrator.opts.reltol)
155183
integrator.u .= nlsol.u
@@ -168,18 +196,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
168196
return
169197
end
170198

171-
function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocationInit,
199+
function _initialize_dae!(integrator, prob::ODEProblem, alg::Union{ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInitExt},
172200
isinplace::Val{false})
173201
@unpack p, t, f = integrator
174202
u0 = integrator.u
175203
M = integrator.f.mass_matrix
176204
dtmax = integrator.opts.dtmax
177205

178-
dt = if alg.initdt === nothing
206+
initdt = get_initdt(alg)
207+
dt = if initdt === nothing
179208
integrator.dt != 0 ? min(integrator.dt / 5, dtmax) :
180209
(prob.tspan[end] - prob.tspan[begin]) / 1000 # Haven't implemented norm reduction
181210
else
182-
alg.initdt
211+
initdt
183212
end
184213

185214
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
@@ -191,7 +220,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
191220

192221
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
193222

194-
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
223+
if isdefined(integrator.cache, :nlsolver) && !isnothing(get_nlsolve(alg))
195224
# backward Euler
196225
nlsolver = integrator.cache.nlsolver
197226
oldγ, oldc, oldmethod,
@@ -225,7 +254,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
225254
jac_prototype = f.jac_prototype,
226255
jac = jac)
227256
nlprob = NonlinearProblem(nlfunc, u0)
228-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, u0)
257+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, nlprob, u0)
229258

230259
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
231260
reltol = integrator.opts.reltol)
@@ -306,7 +335,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
306335
jac_prototype = f.jac_prototype,
307336
jac = jac)
308337
nlprob = NonlinearProblem(nlfunc, u0, p)
309-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
338+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u0, nlprob, isAD)
310339
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
311340
reltol = integrator.opts.reltol)
312341

@@ -350,7 +379,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
350379
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype,
351380
jac = jac)
352381
nlprob = NonlinearProblem(nlfunc, u0)
353-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, u0)
382+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, nlprob, u0)
354383

355384
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
356385
nlprob = NonlinearProblem(nlfunc, u0)
@@ -388,7 +417,7 @@ function algebraic_jacobian(jac_prototype::T, algebraic_eqs,
388417
end
389418

390419
function _initialize_dae!(integrator, prob::ODEProblem,
391-
alg::BrownFullBasicInit, isinplace::Val{true})
420+
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{true})
392421
@unpack p, t, f = integrator
393422
u = integrator.u
394423
M = integrator.f.mass_matrix
@@ -403,7 +432,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
403432

404433
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
405434

406-
check_dae_tolerance(integrator, tmp, alg.abstol, t, isinplace) && return
435+
check_dae_tolerance(integrator, tmp, get_abstol(alg, integrator), t, isinplace) && return
407436
alg_u = @view u[algebraic_vars]
408437

409438
# These non-dual values are thus used to make the caches
@@ -451,9 +480,9 @@ function _initialize_dae!(integrator, prob::ODEProblem,
451480
J = algebraic_jacobian(f.jac_prototype, algebraic_eqs, algebraic_vars)
452481
nlfunc = NonlinearFunction(nlequation!; jac_prototype = J)
453482
nlprob = NonlinearProblem(nlfunc, alg_u, p)
454-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u, nlprob, isAD)
483+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u, nlprob, isAD)
455484

456-
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
485+
nlsol = solve(nlprob, nlsolve; abstol = get_abstol(alg, integrator), reltol = integrator.opts.reltol)
457486
alg_u .= nlsol
458487

459488
recursivecopy!(integrator.uprev, integrator.u)
@@ -469,7 +498,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
469498
end
470499

471500
function _initialize_dae!(integrator, prob::ODEProblem,
472-
alg::BrownFullBasicInit, isinplace::Val{false})
501+
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{false})
473502
@unpack p, t, f = integrator
474503

475504
u0 = integrator.u
@@ -482,7 +511,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
482511
du = f(u0, p, t)
483512
resid = _vec(du)[algebraic_eqs]
484513

485-
check_dae_tolerance(integrator, resid, alg.abstol, t, isinplace) && return
514+
check_dae_tolerance(integrator, resid, get_abstol(alg, integrator), t, isinplace) && return
486515

487516
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
488517
if isAD
@@ -511,7 +540,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
511540
J = algebraic_jacobian(f.jac_prototype, algebraic_eqs, algebraic_vars)
512541
nlfunc = NonlinearFunction(nlequation; jac_prototype = J)
513542
nlprob = NonlinearProblem(nlfunc, u0[algebraic_vars])
514-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
543+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u0, nlprob, isAD)
515544

516545
nlsol = solve(nlprob, nlsolve)
517546

@@ -537,7 +566,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
537566
end
538567

539568
function _initialize_dae!(integrator, prob::DAEProblem,
540-
alg::BrownFullBasicInit, isinplace::Val{true})
569+
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{true})
541570
@unpack p, t, f = integrator
542571
differential_vars = prob.differential_vars
543572
u = integrator.u
@@ -561,7 +590,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
561590
normtmp = get_tmp_cache(integrator)[1]
562591
f(normtmp, du, u, p, t)
563592

564-
if check_dae_tolerance(integrator, normtmp, alg.abstol, t, isinplace)
593+
if check_dae_tolerance(integrator, normtmp, get_abstol(alg, integrator), t, isinplace)
565594
return
566595
elseif differential_vars === nothing
567596
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -592,15 +621,16 @@ function _initialize_dae!(integrator, prob::DAEProblem,
592621
f(out, du_tmp, uu, p, t)
593622
end
594623

595-
if alg.nlsolve !== nothing
596-
nlsolve = alg.nlsolve
624+
nlsolve_alg = get_nlsolve(alg)
625+
if nlsolve_alg !== nothing
626+
nlsolve = nlsolve_alg
597627
else
598628
nlsolve = NewtonRaphson(autodiff = alg_autodiff(integrator.alg))
599629
end
600630

601631
nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype)
602632
nlprob = NonlinearProblem(nlfunc, ifelse.(differential_vars, du, u), p)
603-
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
633+
nlsol = solve(nlprob, nlsolve; abstol = get_abstol(alg, integrator), reltol = integrator.opts.reltol)
604634

605635
@. du = ifelse(differential_vars, nlsol.u, du)
606636
@. u = ifelse(differential_vars, u, nlsol.u)
@@ -618,12 +648,12 @@ function _initialize_dae!(integrator, prob::DAEProblem,
618648
end
619649

620650
function _initialize_dae!(integrator, prob::DAEProblem,
621-
alg::BrownFullBasicInit, isinplace::Val{false})
651+
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{false})
622652
@unpack p, t, f = integrator
623653
differential_vars = prob.differential_vars
624654

625655
if check_dae_tolerance(
626-
integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t, isinplace)
656+
integrator, f(integrator.du, integrator.u, p, t), get_abstol(alg, integrator), t, isinplace)
627657
return
628658
elseif differential_vars === nothing
629659
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -647,7 +677,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
647677
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
648678
nlprob = NonlinearProblem(nlfunc, ifelse.(differential_vars, du, u))
649679

650-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, integrator.u)
680+
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, nlprob, integrator.u)
651681

652682
@show nlsolve
653683

0 commit comments

Comments
 (0)