Skip to content

Commit 85bc8dd

Browse files
Clean up OrdinaryDiffEq initialization algorithm changes
- Remove unnecessary backward compatibility aliases (ShampineCollocationInitExt, BrownBasicInitExt) - Remove unnecessary getter functions - use direct field access instead - Add ODEIntegrator type dispatch to disambiguate from DiffEqBase implementations - Remove redundant DefaultInit delegation This keeps the PR minimal and focused on using the DiffEqBase extended algorithms. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 2f08bb8 commit 85bc8dd

File tree

2 files changed

+33
-59
lines changed

2 files changed

+33
-59
lines changed

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
# Import all initialization algorithms from DiffEqBase
22
import DiffEqBase: DefaultInit, ShampineCollocationInit, BrownBasicInit, BrownFullBasicInit
33

4-
# Re-export for backward compatibility
4+
# Re-export for convenience
55
export DefaultInit, ShampineCollocationInit, BrownBasicInit, BrownFullBasicInit
66

7-
# Legacy aliases for backward compatibility
8-
const ShampineCollocationInitExt = ShampineCollocationInit
9-
const BrownBasicInitExt = BrownFullBasicInit
10-
export ShampineCollocationInitExt, BrownBasicInitExt
11-
127
## Notes
138

149
#=
@@ -33,7 +28,7 @@ end
3328

3429
## Default algorithms
3530

36-
function _initialize_dae!(integrator, prob::ODEProblem,
31+
function _initialize_dae!(integrator::ODEIntegrator, prob::ODEProblem,
3732
alg::DefaultInit, x::Union{Val{true}, Val{false}})
3833
if SciMLBase.has_initializeprob(prob.f)
3934
_initialize_dae!(integrator, prob,
@@ -47,7 +42,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
4742
end
4843
end
4944

50-
function _initialize_dae!(integrator, prob::DAEProblem,
45+
function _initialize_dae!(integrator::ODEIntegrator, prob::DAEProblem,
5146
alg::DefaultInit, x::Union{Val{true}, Val{false}})
5247
if SciMLBase.has_initializeprob(prob.f)
5348
_initialize_dae!(integrator, prob,
@@ -66,7 +61,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
6661
end
6762
end
6863

69-
function _initialize_dae!(integrator, prob::DiscreteProblem,
64+
function _initialize_dae!(integrator::ODEIntegrator, prob::DiscreteProblem,
7065
alg::DefaultInit, x::Union{Val{true}, Val{false}})
7166
if SciMLBase.has_initializeprob(prob.f)
7267
# integrator.opts.abstol is `false` for `DiscreteProblem`.
@@ -113,13 +108,13 @@ end
113108

114109
## NoInit
115110

116-
function _initialize_dae!(integrator, prob::AbstractDEProblem,
111+
function _initialize_dae!(integrator::ODEIntegrator, prob::AbstractDEProblem,
117112
alg::NoInit, x::Union{Val{true}, Val{false}})
118113
end
119114

120115
## OverrideInit
121116

122-
function _initialize_dae!(integrator, prob::AbstractDEProblem,
117+
function _initialize_dae!(integrator::ODEIntegrator, prob::AbstractDEProblem,
123118
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
124119
initializeprob = prob.f.initialization_data.initializeprob
125120

@@ -161,17 +156,10 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem,
161156
end
162157

163158
## CheckInit
164-
function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
159+
function _initialize_dae!(integrator::ODEIntegrator, prob::AbstractDEProblem, alg::CheckInit,
165160
isinplace::Union{Val{true}, Val{false}})
166161
SciMLBase.get_initial_values(
167162
prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol)
168163
end
169164

170-
# Delegate base DiffEqBase types to extended versions with default options
171165
# No longer needed - DiffEqBase types now have the parameters directly
172-
173-
# Handle DiffEqBase.DefaultInit same as our DefaultInit
174-
function _initialize_dae!(integrator, prob::AbstractDEProblem,
175-
alg::DiffEqBase.DefaultInit, isinplace::Union{Val{true}, Val{false}})
176-
_initialize_dae!(integrator, prob, DefaultInit(), isinplace)
177-
end

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,6 @@ end
1616
end
1717
end
1818

19-
# Helper to get abstol from either the algorithm or integrator opts
20-
@inline function get_abstol(alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, integrator)
21-
return alg.abstol
22-
end
23-
24-
# Helper to get nlsolve from either the algorithm or nothing
25-
@inline function get_nlsolve(alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit, DiffEqBase.ShampineCollocationInit})
26-
return alg.nlsolve
27-
end
28-
29-
# Helper to get initdt from either the algorithm or nothing
30-
@inline function get_initdt(alg::DiffEqBase.ShampineCollocationInit)
31-
return alg.initdt
32-
end
3319

3420
function default_nlsolve(
3521
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
@@ -68,15 +54,15 @@ Solve for `u`
6854
6955
=#
7056

71-
function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
57+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
7258
isinplace::Val{true})
7359
@unpack p, t, f = integrator
7460
M = integrator.f.mass_matrix
7561
dtmax = integrator.opts.dtmax
7662
tmp = first(get_tmp_cache(integrator))
7763
u0 = integrator.u
7864

79-
initdt = get_initdt(alg)
65+
initdt = alg.initdt
8066
dt = if initdt === nothing
8167
integrator.dt != 0 ? min(integrator.dt / 5, dtmax) :
8268
(prob.tspan[end] - prob.tspan[begin]) / 1000 # Haven't implemented norm reduction
@@ -93,7 +79,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.Shampine
9379

9480
check_dae_tolerance(integrator, tmp, integrator.opts.abstol, t, isinplace) && return
9581

96-
if isdefined(integrator.cache, :nlsolver) && !isnothing(get_nlsolve(alg))
82+
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
9783
# backward Euler
9884
nlsolver = integrator.cache.nlsolver
9985
oldγ, oldc, oldmethod,
@@ -165,7 +151,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.Shampine
165151
jac_prototype = f.jac_prototype,
166152
jac = jac)
167153
nlprob = NonlinearProblem(nlfunc, integrator.u, p)
168-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u0, nlprob, isAD)
154+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
169155
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
170156
reltol = integrator.opts.reltol)
171157
integrator.u .= nlsol.u
@@ -184,14 +170,14 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.Shampine
184170
return
185171
end
186172

187-
function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
173+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
188174
isinplace::Val{false})
189175
@unpack p, t, f = integrator
190176
u0 = integrator.u
191177
M = integrator.f.mass_matrix
192178
dtmax = integrator.opts.dtmax
193179

194-
initdt = get_initdt(alg)
180+
initdt = alg.initdt
195181
dt = if initdt === nothing
196182
integrator.dt != 0 ? min(integrator.dt / 5, dtmax) :
197183
(prob.tspan[end] - prob.tspan[begin]) / 1000 # Haven't implemented norm reduction
@@ -208,7 +194,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.Shampine
208194

209195
check_dae_tolerance(integrator, resid, integrator.opts.abstol, t, isinplace) && return
210196

211-
if isdefined(integrator.cache, :nlsolver) && !isnothing(get_nlsolve(alg))
197+
if isdefined(integrator.cache, :nlsolver) && !isnothing(alg.nlsolve)
212198
# backward Euler
213199
nlsolver = integrator.cache.nlsolver
214200
oldγ, oldc, oldmethod,
@@ -242,7 +228,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.Shampine
242228
jac_prototype = f.jac_prototype,
243229
jac = jac)
244230
nlprob = NonlinearProblem(nlfunc, u0)
245-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, nlprob, u0)
231+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, u0)
246232

247233
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
248234
reltol = integrator.opts.reltol)
@@ -263,7 +249,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.Shampine
263249
return
264250
end
265251

266-
function _initialize_dae!(integrator, prob::DAEProblem,
252+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
267253
alg::ShampineCollocationInit, isinplace::Val{true})
268254
@unpack p, t, f = integrator
269255
u0 = integrator.u
@@ -323,7 +309,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
323309
jac_prototype = f.jac_prototype,
324310
jac = jac)
325311
nlprob = NonlinearProblem(nlfunc, u0, p)
326-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u0, nlprob, isAD)
312+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
327313
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
328314
reltol = integrator.opts.reltol)
329315

@@ -340,7 +326,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
340326
return
341327
end
342328

343-
function _initialize_dae!(integrator, prob::DAEProblem,
329+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
344330
alg::ShampineCollocationInit, isinplace::Val{false})
345331
@unpack p, t, f = integrator
346332
u0 = integrator.u
@@ -367,7 +353,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
367353
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype,
368354
jac = jac)
369355
nlprob = NonlinearProblem(nlfunc, u0)
370-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, nlprob, u0)
356+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, u0)
371357

372358
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
373359
nlprob = NonlinearProblem(nlfunc, u0)
@@ -404,7 +390,7 @@ function algebraic_jacobian(jac_prototype::T, algebraic_eqs,
404390
jac_prototype[algebraic_eqs, algebraic_vars]
405391
end
406392

407-
function _initialize_dae!(integrator, prob::ODEProblem,
393+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem,
408394
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{true})
409395
@unpack p, t, f = integrator
410396
u = integrator.u
@@ -420,7 +406,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
420406

421407
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
422408

423-
check_dae_tolerance(integrator, tmp, get_abstol(alg, integrator), t, isinplace) && return
409+
check_dae_tolerance(integrator, tmp, alg.abstol, t, isinplace) && return
424410
alg_u = @view u[algebraic_vars]
425411

426412
# These non-dual values are thus used to make the caches
@@ -468,9 +454,9 @@ function _initialize_dae!(integrator, prob::ODEProblem,
468454
J = algebraic_jacobian(f.jac_prototype, algebraic_eqs, algebraic_vars)
469455
nlfunc = NonlinearFunction(nlequation!; jac_prototype = J)
470456
nlprob = NonlinearProblem(nlfunc, alg_u, p)
471-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u, nlprob, isAD)
457+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u, nlprob, isAD)
472458

473-
nlsol = solve(nlprob, nlsolve; abstol = get_abstol(alg, integrator), reltol = integrator.opts.reltol)
459+
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
474460
alg_u .= nlsol
475461

476462
recursivecopy!(integrator.uprev, integrator.u)
@@ -485,7 +471,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
485471
return
486472
end
487473

488-
function _initialize_dae!(integrator, prob::ODEProblem,
474+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::ODEProblem,
489475
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{false})
490476
@unpack p, t, f = integrator
491477

@@ -499,7 +485,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
499485
du = f(u0, p, t)
500486
resid = _vec(du)[algebraic_eqs]
501487

502-
check_dae_tolerance(integrator, resid, get_abstol(alg, integrator), t, isinplace) && return
488+
check_dae_tolerance(integrator, resid, alg.abstol, t, isinplace) && return
503489

504490
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
505491
if isAD
@@ -528,7 +514,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
528514
J = algebraic_jacobian(f.jac_prototype, algebraic_eqs, algebraic_vars)
529515
nlfunc = NonlinearFunction(nlequation; jac_prototype = J)
530516
nlprob = NonlinearProblem(nlfunc, u0[algebraic_vars])
531-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, u0, nlprob, isAD)
517+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
532518

533519
nlsol = solve(nlprob, nlsolve)
534520

@@ -553,7 +539,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
553539
return
554540
end
555541

556-
function _initialize_dae!(integrator, prob::DAEProblem,
542+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
557543
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{true})
558544
@unpack p, t, f = integrator
559545
differential_vars = prob.differential_vars
@@ -578,7 +564,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
578564
normtmp = get_tmp_cache(integrator)[1]
579565
f(normtmp, du, u, p, t)
580566

581-
if check_dae_tolerance(integrator, normtmp, get_abstol(alg, integrator), t, isinplace)
567+
if check_dae_tolerance(integrator, normtmp, alg.abstol, t, isinplace)
582568
return
583569
elseif differential_vars === nothing
584570
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -609,7 +595,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
609595
f(out, du_tmp, uu, p, t)
610596
end
611597

612-
nlsolve_alg = get_nlsolve(alg)
598+
nlsolve_alg = alg.nlsolve
613599
if nlsolve_alg !== nothing
614600
nlsolve = nlsolve_alg
615601
else
@@ -618,7 +604,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
618604

619605
nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype)
620606
nlprob = NonlinearProblem(nlfunc, ifelse.(differential_vars, du, u), p)
621-
nlsol = solve(nlprob, nlsolve; abstol = get_abstol(alg, integrator), reltol = integrator.opts.reltol)
607+
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
622608

623609
@. du = ifelse(differential_vars, nlsol.u, du)
624610
@. u = ifelse(differential_vars, u, nlsol.u)
@@ -635,13 +621,13 @@ function _initialize_dae!(integrator, prob::DAEProblem,
635621
return
636622
end
637623

638-
function _initialize_dae!(integrator, prob::DAEProblem,
624+
function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::DAEProblem,
639625
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{false})
640626
@unpack p, t, f = integrator
641627
differential_vars = prob.differential_vars
642628

643629
if check_dae_tolerance(
644-
integrator, f(integrator.du, integrator.u, p, t), get_abstol(alg, integrator), t, isinplace)
630+
integrator, f(integrator.du, integrator.u, p, t), alg.abstol, t, isinplace)
645631
return
646632
elseif differential_vars === nothing
647633
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm.")
@@ -665,7 +651,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
665651
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
666652
nlprob = NonlinearProblem(nlfunc, ifelse.(differential_vars, du, u))
667653

668-
nlsolve = default_nlsolve(get_nlsolve(alg), isinplace, nlprob, integrator.u)
654+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, integrator.u)
669655

670656
@show nlsolve
671657

0 commit comments

Comments
 (0)