Skip to content

Commit 3f289c3

Browse files
Use DiffEqBase extended initialization algorithms
Now that DiffEqBase has the extended versions of BrownBasicInit and ShampineCollocationInit with parameters, OrdinaryDiffEq can use them directly instead of maintaining its own extended types. Changes: - Import BrownFullBasicInit from DiffEqBase - Remove local extended type definitions - Create aliases for backward compatibility - Update OrdinaryDiffEqNonlinearSolve to use DiffEqBase types This simplifies the code and reduces duplication across the ecosystem. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 355d125 commit 3f289c3

File tree

2 files changed

+15
-56
lines changed

2 files changed

+15
-56
lines changed

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,13 @@
1-
# Import standard initialization algorithms from DiffEqBase
2-
import DiffEqBase: DefaultInit, ShampineCollocationInit, BrownBasicInit
1+
# Import all initialization algorithms from DiffEqBase
2+
import DiffEqBase: DefaultInit, ShampineCollocationInit, BrownBasicInit, BrownFullBasicInit
33

44
# Re-export for backward compatibility
5-
export DefaultInit, ShampineCollocationInit, BrownBasicInit
5+
export DefaultInit, ShampineCollocationInit, BrownBasicInit, BrownFullBasicInit
66

7-
# Extended versions with OrdinaryDiffEq-specific options
8-
struct ShampineCollocationInitExt{T, F} <: SciMLBase.DAEInitializationAlgorithm
9-
initdt::T
10-
nlsolve::F
11-
end
12-
function ShampineCollocationInitExt(; initdt = nothing, nlsolve = nothing)
13-
ShampineCollocationInitExt(initdt, nlsolve)
14-
end
15-
function ShampineCollocationInitExt(initdt)
16-
ShampineCollocationInitExt(; initdt = initdt, nlsolve = nothing)
17-
end
18-
19-
# Constructor for backward compatibility when passing initdt
20-
function ShampineCollocationInit(initdt::T) where T
21-
ShampineCollocationInitExt(; initdt = initdt, nlsolve = nothing)
22-
end
23-
24-
struct BrownFullBasicInit{T, F} <: SciMLBase.DAEInitializationAlgorithm
25-
abstol::T
26-
nlsolve::F
27-
end
28-
function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
29-
BrownFullBasicInit(abstol, nlsolve)
30-
end
31-
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)
32-
33-
# Alias for consistency with DiffEqBase naming
7+
# Legacy aliases for backward compatibility
8+
const ShampineCollocationInitExt = ShampineCollocationInit
349
const BrownBasicInitExt = BrownFullBasicInit
35-
36-
# Constructor for backward compatibility when passing abstol
37-
function BrownBasicInit(abstol::T) where T
38-
BrownFullBasicInit(abstol, nothing)
39-
end
10+
export ShampineCollocationInitExt, BrownBasicInitExt
4011

4112
## Notes
4213

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,20 @@ end
1717
end
1818

1919
# Helper to get abstol from either the algorithm or integrator opts
20-
@inline function get_abstol(alg::Union{BrownFullBasicInit, OrdinaryDiffEqCore.BrownFullBasicInit}, integrator)
20+
@inline function get_abstol(alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, integrator)
2121
return alg.abstol
2222
end
2323

24-
@inline function get_abstol(alg::DiffEqBase.BrownBasicInit, integrator)
25-
return integrator.opts.abstol
26-
end
27-
2824
# Helper to get nlsolve from either the algorithm or nothing
29-
@inline function get_nlsolve(alg::Union{BrownFullBasicInit, OrdinaryDiffEqCore.BrownFullBasicInit, OrdinaryDiffEqCore.ShampineCollocationInitExt})
25+
@inline function get_nlsolve(alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit, DiffEqBase.ShampineCollocationInit})
3026
return alg.nlsolve
3127
end
3228

33-
@inline function get_nlsolve(alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.ShampineCollocationInit})
34-
return nothing
35-
end
36-
3729
# Helper to get initdt from either the algorithm or nothing
38-
@inline function get_initdt(alg::OrdinaryDiffEqCore.ShampineCollocationInitExt)
30+
@inline function get_initdt(alg::DiffEqBase.ShampineCollocationInit)
3931
return alg.initdt
4032
end
4133

42-
@inline function get_initdt(alg::Union{DiffEqBase.ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInit})
43-
return nothing
44-
end
45-
4634
function default_nlsolve(
4735
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
4836
FastShortcutNonlinearPolyalg(;
@@ -80,7 +68,7 @@ Solve for `u`
8068
8169
=#
8270

83-
function _initialize_dae!(integrator, prob::ODEProblem, alg::Union{ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInitExt},
71+
function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
8472
isinplace::Val{true})
8573
@unpack p, t, f = integrator
8674
M = integrator.f.mass_matrix
@@ -196,7 +184,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::Union{ShampineCollo
196184
return
197185
end
198186

199-
function _initialize_dae!(integrator, prob::ODEProblem, alg::Union{ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInitExt},
187+
function _initialize_dae!(integrator, prob::ODEProblem, alg::DiffEqBase.ShampineCollocationInit,
200188
isinplace::Val{false})
201189
@unpack p, t, f = integrator
202190
u0 = integrator.u
@@ -417,7 +405,7 @@ function algebraic_jacobian(jac_prototype::T, algebraic_eqs,
417405
end
418406

419407
function _initialize_dae!(integrator, prob::ODEProblem,
420-
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{true})
408+
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{true})
421409
@unpack p, t, f = integrator
422410
u = integrator.u
423411
M = integrator.f.mass_matrix
@@ -498,7 +486,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
498486
end
499487

500488
function _initialize_dae!(integrator, prob::ODEProblem,
501-
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{false})
489+
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{false})
502490
@unpack p, t, f = integrator
503491

504492
u0 = integrator.u
@@ -566,7 +554,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
566554
end
567555

568556
function _initialize_dae!(integrator, prob::DAEProblem,
569-
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{true})
557+
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{true})
570558
@unpack p, t, f = integrator
571559
differential_vars = prob.differential_vars
572560
u = integrator.u
@@ -648,7 +636,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
648636
end
649637

650638
function _initialize_dae!(integrator, prob::DAEProblem,
651-
alg::Union{BrownFullBasicInit, DiffEqBase.BrownBasicInit}, isinplace::Val{false})
639+
alg::Union{DiffEqBase.BrownBasicInit, DiffEqBase.BrownFullBasicInit}, isinplace::Val{false})
652640
@unpack p, t, f = integrator
653641
differential_vars = prob.differential_vars
654642

0 commit comments

Comments
 (0)