1616 end
1717end
1818
19+ # Helper to get abstol from either the algorithm or integrator opts
20+ @inline function get_abstol (alg:: Union{BrownFullBasicInit, OrdinaryDiffEqCore.BrownFullBasicInit} , integrator)
21+ return alg. abstol
22+ end
23+
24+ @inline function get_abstol (alg:: DiffEqBase.BrownBasicInit , integrator)
25+ return integrator. opts. abstol
26+ end
27+
28+ # Helper to get nlsolve from either the algorithm or nothing
29+ @inline function get_nlsolve (alg:: Union{BrownFullBasicInit, OrdinaryDiffEqCore.BrownFullBasicInit, OrdinaryDiffEqCore.ShampineCollocationInitExt} )
30+ return alg. nlsolve
31+ end
32+
33+ @inline function get_nlsolve (alg:: Union{DiffEqBase.BrownBasicInit, DiffEqBase.ShampineCollocationInit} )
34+ return nothing
35+ end
36+
37+ # Helper to get initdt from either the algorithm or nothing
38+ @inline function get_initdt (alg:: OrdinaryDiffEqCore.ShampineCollocationInitExt )
39+ return alg. initdt
40+ end
41+
42+ @inline function get_initdt (alg:: Union{DiffEqBase.ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInit} )
43+ return nothing
44+ end
45+
1946function default_nlsolve (
2047 :: Nothing , isinplace:: Val{true} , u, :: AbstractNonlinearProblem , autodiff = false )
2148 FastShortcutNonlinearPolyalg (;
@@ -53,19 +80,20 @@ Solve for `u`
5380
5481=#
5582
56- function _initialize_dae! (integrator, prob:: ODEProblem , alg:: ShampineCollocationInit ,
83+ function _initialize_dae! (integrator, prob:: ODEProblem , alg:: Union{ ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInitExt} ,
5784 isinplace:: Val{true} )
5885 @unpack p, t, f = integrator
5986 M = integrator. f. mass_matrix
6087 dtmax = integrator. opts. dtmax
6188 tmp = first (get_tmp_cache (integrator))
6289 u0 = integrator. u
6390
64- dt = if alg. initdt === nothing
91+ initdt = get_initdt (alg)
92+ dt = if initdt === nothing
6593 integrator. dt != 0 ? min (integrator. dt / 5 , dtmax) :
6694 (prob. tspan[end ] - prob. tspan[begin ]) / 1000 # Haven't implemented norm reduction
6795 else
68- alg . initdt
96+ initdt
6997 end
7098
7199 algebraic_vars = [all (iszero, x) for x in eachcol (M)]
@@ -77,7 +105,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
77105
78106 check_dae_tolerance (integrator, tmp, integrator. opts. abstol, t, isinplace) && return
79107
80- if isdefined (integrator. cache, :nlsolver ) && ! isnothing (alg. nlsolve )
108+ if isdefined (integrator. cache, :nlsolver ) && ! isnothing (get_nlsolve ( alg) )
81109 # backward Euler
82110 nlsolver = integrator. cache. nlsolver
83111 oldγ, oldc, oldmethod,
@@ -149,7 +177,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
149177 jac_prototype = f. jac_prototype,
150178 jac = jac)
151179 nlprob = NonlinearProblem (nlfunc, integrator. u, p)
152- nlsolve = default_nlsolve (alg. nlsolve , isinplace, u0, nlprob, isAD)
180+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, u0, nlprob, isAD)
153181 nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
154182 reltol = integrator. opts. reltol)
155183 integrator. u .= nlsol. u
@@ -168,18 +196,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
168196 return
169197end
170198
171- function _initialize_dae! (integrator, prob:: ODEProblem , alg:: ShampineCollocationInit ,
199+ function _initialize_dae! (integrator, prob:: ODEProblem , alg:: Union{ ShampineCollocationInit, OrdinaryDiffEqCore.ShampineCollocationInitExt} ,
172200 isinplace:: Val{false} )
173201 @unpack p, t, f = integrator
174202 u0 = integrator. u
175203 M = integrator. f. mass_matrix
176204 dtmax = integrator. opts. dtmax
177205
178- dt = if alg. initdt === nothing
206+ initdt = get_initdt (alg)
207+ dt = if initdt === nothing
179208 integrator. dt != 0 ? min (integrator. dt / 5 , dtmax) :
180209 (prob. tspan[end ] - prob. tspan[begin ]) / 1000 # Haven't implemented norm reduction
181210 else
182- alg . initdt
211+ initdt
183212 end
184213
185214 algebraic_vars = [all (iszero, x) for x in eachcol (M)]
@@ -191,7 +220,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
191220
192221 check_dae_tolerance (integrator, resid, integrator. opts. abstol, t, isinplace) && return
193222
194- if isdefined (integrator. cache, :nlsolver ) && ! isnothing (alg. nlsolve )
223+ if isdefined (integrator. cache, :nlsolver ) && ! isnothing (get_nlsolve ( alg) )
195224 # backward Euler
196225 nlsolver = integrator. cache. nlsolver
197226 oldγ, oldc, oldmethod,
@@ -225,7 +254,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
225254 jac_prototype = f. jac_prototype,
226255 jac = jac)
227256 nlprob = NonlinearProblem (nlfunc, u0)
228- nlsolve = default_nlsolve (alg. nlsolve , isinplace, nlprob, u0)
257+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, nlprob, u0)
229258
230259 nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
231260 reltol = integrator. opts. reltol)
@@ -306,7 +335,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
306335 jac_prototype = f. jac_prototype,
307336 jac = jac)
308337 nlprob = NonlinearProblem (nlfunc, u0, p)
309- nlsolve = default_nlsolve (alg. nlsolve , isinplace, u0, nlprob, isAD)
338+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, u0, nlprob, isAD)
310339 nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
311340 reltol = integrator. opts. reltol)
312341
@@ -350,7 +379,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
350379 nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype,
351380 jac = jac)
352381 nlprob = NonlinearProblem (nlfunc, u0)
353- nlsolve = default_nlsolve (alg. nlsolve , isinplace, nlprob, u0)
382+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, nlprob, u0)
354383
355384 nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype)
356385 nlprob = NonlinearProblem (nlfunc, u0)
@@ -388,7 +417,7 @@ function algebraic_jacobian(jac_prototype::T, algebraic_eqs,
388417end
389418
390419function _initialize_dae! (integrator, prob:: ODEProblem ,
391- alg:: BrownFullBasicInit , isinplace:: Val{true} )
420+ alg:: Union{ BrownFullBasicInit, DiffEqBase.BrownBasicInit} , isinplace:: Val{true} )
392421 @unpack p, t, f = integrator
393422 u = integrator. u
394423 M = integrator. f. mass_matrix
@@ -403,7 +432,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
403432
404433 tmp .= ArrayInterface. restructure (tmp, algebraic_eqs .* _vec (tmp))
405434
406- check_dae_tolerance (integrator, tmp, alg. abstol , t, isinplace) && return
435+ check_dae_tolerance (integrator, tmp, get_abstol ( alg, integrator) , t, isinplace) && return
407436 alg_u = @view u[algebraic_vars]
408437
409438 # These non-dual values are thus used to make the caches
@@ -451,9 +480,9 @@ function _initialize_dae!(integrator, prob::ODEProblem,
451480 J = algebraic_jacobian (f. jac_prototype, algebraic_eqs, algebraic_vars)
452481 nlfunc = NonlinearFunction (nlequation!; jac_prototype = J)
453482 nlprob = NonlinearProblem (nlfunc, alg_u, p)
454- nlsolve = default_nlsolve (alg. nlsolve , isinplace, u, nlprob, isAD)
483+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, u, nlprob, isAD)
455484
456- nlsol = solve (nlprob, nlsolve; abstol = alg. abstol , reltol = integrator. opts. reltol)
485+ nlsol = solve (nlprob, nlsolve; abstol = get_abstol ( alg, integrator) , reltol = integrator. opts. reltol)
457486 alg_u .= nlsol
458487
459488 recursivecopy! (integrator. uprev, integrator. u)
@@ -469,7 +498,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
469498end
470499
471500function _initialize_dae! (integrator, prob:: ODEProblem ,
472- alg:: BrownFullBasicInit , isinplace:: Val{false} )
501+ alg:: Union{ BrownFullBasicInit, DiffEqBase.BrownBasicInit} , isinplace:: Val{false} )
473502 @unpack p, t, f = integrator
474503
475504 u0 = integrator. u
@@ -482,7 +511,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
482511 du = f (u0, p, t)
483512 resid = _vec (du)[algebraic_eqs]
484513
485- check_dae_tolerance (integrator, resid, alg. abstol , t, isinplace) && return
514+ check_dae_tolerance (integrator, resid, get_abstol ( alg, integrator) , t, isinplace) && return
486515
487516 isAD = alg_autodiff (integrator. alg) isa AutoForwardDiff
488517 if isAD
@@ -511,7 +540,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
511540 J = algebraic_jacobian (f. jac_prototype, algebraic_eqs, algebraic_vars)
512541 nlfunc = NonlinearFunction (nlequation; jac_prototype = J)
513542 nlprob = NonlinearProblem (nlfunc, u0[algebraic_vars])
514- nlsolve = default_nlsolve (alg. nlsolve , isinplace, u0, nlprob, isAD)
543+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, u0, nlprob, isAD)
515544
516545 nlsol = solve (nlprob, nlsolve)
517546
@@ -537,7 +566,7 @@ function _initialize_dae!(integrator, prob::ODEProblem,
537566end
538567
539568function _initialize_dae! (integrator, prob:: DAEProblem ,
540- alg:: BrownFullBasicInit , isinplace:: Val{true} )
569+ alg:: Union{ BrownFullBasicInit, DiffEqBase.BrownBasicInit} , isinplace:: Val{true} )
541570 @unpack p, t, f = integrator
542571 differential_vars = prob. differential_vars
543572 u = integrator. u
@@ -561,7 +590,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
561590 normtmp = get_tmp_cache (integrator)[1 ]
562591 f (normtmp, du, u, p, t)
563592
564- if check_dae_tolerance (integrator, normtmp, alg. abstol , t, isinplace)
593+ if check_dae_tolerance (integrator, normtmp, get_abstol ( alg, integrator) , t, isinplace)
565594 return
566595 elseif differential_vars === nothing
567596 error (" differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm." )
@@ -592,15 +621,16 @@ function _initialize_dae!(integrator, prob::DAEProblem,
592621 f (out, du_tmp, uu, p, t)
593622 end
594623
595- if alg. nlsolve != = nothing
596- nlsolve = alg. nlsolve
624+ nlsolve_alg = get_nlsolve (alg)
625+ if nlsolve_alg != = nothing
626+ nlsolve = nlsolve_alg
597627 else
598628 nlsolve = NewtonRaphson (autodiff = alg_autodiff (integrator. alg))
599629 end
600630
601631 nlfunc = NonlinearFunction (nlequation!; jac_prototype = f. jac_prototype)
602632 nlprob = NonlinearProblem (nlfunc, ifelse .(differential_vars, du, u), p)
603- nlsol = solve (nlprob, nlsolve; abstol = alg. abstol , reltol = integrator. opts. reltol)
633+ nlsol = solve (nlprob, nlsolve; abstol = get_abstol ( alg, integrator) , reltol = integrator. opts. reltol)
604634
605635 @. du = ifelse (differential_vars, nlsol. u, du)
606636 @. u = ifelse (differential_vars, u, nlsol. u)
@@ -618,12 +648,12 @@ function _initialize_dae!(integrator, prob::DAEProblem,
618648end
619649
620650function _initialize_dae! (integrator, prob:: DAEProblem ,
621- alg:: BrownFullBasicInit , isinplace:: Val{false} )
651+ alg:: Union{ BrownFullBasicInit, DiffEqBase.BrownBasicInit} , isinplace:: Val{false} )
622652 @unpack p, t, f = integrator
623653 differential_vars = prob. differential_vars
624654
625655 if check_dae_tolerance (
626- integrator, f (integrator. du, integrator. u, p, t), alg. abstol , t, isinplace)
656+ integrator, f (integrator. du, integrator. u, p, t), get_abstol ( alg, integrator) , t, isinplace)
627657 return
628658 elseif differential_vars === nothing
629659 error (" differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions, differential_vars, or use a different initialization algorithm." )
@@ -647,7 +677,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
647677 nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype)
648678 nlprob = NonlinearProblem (nlfunc, ifelse .(differential_vars, du, u))
649679
650- nlsolve = default_nlsolve (alg. nlsolve , isinplace, nlprob, integrator. u)
680+ nlsolve = default_nlsolve (get_nlsolve ( alg) , isinplace, nlprob, integrator. u)
651681
652682 @show nlsolve
653683
0 commit comments