Skip to content

Commit 0a66033

Browse files
refactor: generalize _initialize_dae! to use SciMLBase implementations
1 parent ad7891e commit 0a66033

File tree

2 files changed

+18
-107
lines changed

2 files changed

+18
-107
lines changed

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape
6060

6161
using FastBroadcast: @.., True, False
6262

63-
using SciMLBase: NoInit, CheckInit, _unwrap_val
63+
using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val
6464

6565
import SciMLBase: alg_order
6666

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 17 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,6 @@ function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
2020
end
2121
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)
2222

23-
struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm
24-
abstol::T
25-
nlsolve::F
26-
end
27-
28-
function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
29-
OverrideInit(abstol, nlsolve)
30-
end
31-
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)
32-
3323
## Notes
3424

3525
#=
@@ -143,19 +133,15 @@ end
143133

144134
## NoInit
145135

146-
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
136+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
147137
alg::NoInit, x::Union{Val{true}, Val{false}})
148138
end
149139

150140
## OverrideInit
151141

152-
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
142+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
153143
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
154-
initializeprob = prob.f.initializeprob
155-
156-
if SciMLBase.has_update_initializeprob!(prob.f)
157-
prob.f.update_initializeprob!(initializeprob, prob)
158-
end
144+
initializeprob = prob.f.initialization_data.initializeprob
159145

160146
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
161147
# Since then it's the case of not a DAE but has initializeprob
@@ -168,105 +154,30 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
168154
true
169155
end
170156

171-
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
172-
nlsol = solve(initializeprob, alg)
157+
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
158+
159+
u0, p, success = SciMLBase.get_initial_values(prob, prob.f, integrator, alg, isinplace; nlsolve_alg)
160+
173161
if isinplace === Val{true}()
174-
integrator.u .= prob.f.initializeprobmap(nlsol)
162+
integrator.u .= u0
175163
elseif isinplace === Val{false}()
176-
integrator.u = prob.f.initializeprobmap(nlsol)
164+
integrator.u = u0
177165
else
178166
error("Unreachable reached. Report this error.")
179167
end
180-
if SciMLBase.has_initializeprobpmap(prob.f)
181-
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
182-
sol = integrator.sol
183-
@reset sol.prob.p = integrator.p
184-
integrator.sol = sol
185-
end
168+
integrator.p = p
169+
sol = integrator.sol
170+
@reset sol.prob.p = integrator.p
171+
integrator.sol = sol
186172

187-
if nlsol.retcode != ReturnCode.Success
173+
if !success
188174
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
189175
ReturnCode.InitialFailure)
190176
end
191177
end
192178

193179
## CheckInit
194-
struct CheckInitFailureError <: Exception
195-
normresid::Any
196-
abstol::Any
197-
end
198-
199-
function Base.showerror(io::IO, e::CheckInitFailureError)
200-
print(io,
201-
"CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
202-
end
203-
204-
function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
205-
isinplace::Val{true})
206-
@unpack p, t, f = integrator
207-
M = integrator.f.mass_matrix
208-
tmp = first(get_tmp_cache(integrator))
209-
u0 = integrator.u
210-
211-
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
212-
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
213-
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
214-
update_coefficients!(M, u0, p, t)
215-
f(tmp, u0, p, t)
216-
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
217-
218-
normresid = integrator.opts.internalnorm(tmp, t)
219-
if normresid > integrator.opts.abstol
220-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
221-
end
222-
end
223-
224-
function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
225-
isinplace::Val{false})
226-
@unpack p, t, f = integrator
227-
u0 = integrator.u
228-
M = integrator.f.mass_matrix
229-
230-
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
231-
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
232-
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
233-
update_coefficients!(M, u0, p, t)
234-
du = f(u0, p, t)
235-
resid = _vec(du)[algebraic_eqs]
236-
237-
normresid = integrator.opts.internalnorm(resid, t)
238-
if normresid > integrator.opts.abstol
239-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
240-
end
241-
end
242-
243-
function _initialize_dae!(integrator, prob::DAEProblem,
244-
alg::CheckInit, isinplace::Val{true})
245-
@unpack p, t, f = integrator
246-
u0 = integrator.u
247-
resid = get_tmp_cache(integrator)[2]
248-
249-
f(resid, integrator.du, u0, p, t)
250-
normresid = integrator.opts.internalnorm(resid, t)
251-
if normresid > integrator.opts.abstol
252-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
253-
end
254-
end
255-
256-
function _initialize_dae!(integrator, prob::DAEProblem,
257-
alg::CheckInit, isinplace::Val{false})
258-
@unpack p, t, f = integrator
259-
u0 = integrator.u
260-
261-
nlequation_oop = u -> begin
262-
f((u - u0) / dt, u, p, t)
263-
end
264-
265-
nlequation = (u, _) -> nlequation_oop(u)
266-
267-
resid = f(integrator.du, u0, p, t)
268-
normresid = integrator.opts.internalnorm(resid, t)
269-
if normresid > integrator.opts.abstol
270-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
271-
end
180+
function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
181+
isinplace::Union{Val{true}, Val{false}})
182+
SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace)
272183
end

0 commit comments

Comments
 (0)