Skip to content

Commit 4ecb539

Browse files
Merge pull request #2521 from AayushSabharwal/as/rework-init
refactor: generalize `_initialize_dae!` to use SciMLBase implementations
2 parents 9ab927b + c7b28ff commit 4ecb539

File tree

5 files changed

+24
-113
lines changed

5 files changed

+24
-113
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ jobs:
9595
with:
9696
token: ${{ secrets.CODECOV_TOKEN }}
9797
file: lcov.info
98-
fail_ci_if_error: true
98+
fail_ci_if_error: false

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Random = "<0.0.1, 1"
7070
RecursiveArrayTools = "2.36, 3"
7171
Reexport = "1.0"
7272
SafeTestsets = "0.1.0"
73-
SciMLBase = "2.60"
73+
SciMLBase = "2.62"
7474
SciMLOperators = "0.3"
7575
SciMLStructures = "1"
7676
SimpleUnPack = "1"

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape
6060

6161
using FastBroadcast: @.., True, False
6262

63-
using SciMLBase: NoInit, CheckInit, _unwrap_val
63+
using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val
6464

6565
import SciMLBase: alg_order
6666

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 17 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,6 @@ function BrownFullBasicInit(; abstol = 1e-10, nlsolve = nothing)
2020
end
2121
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)
2222

23-
struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm
24-
abstol::T
25-
nlsolve::F
26-
end
27-
28-
function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
29-
OverrideInit(abstol, nlsolve)
30-
end
31-
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)
32-
3323
## Notes
3424

3525
#=
@@ -143,19 +133,15 @@ end
143133

144134
## NoInit
145135

146-
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
136+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
147137
alg::NoInit, x::Union{Val{true}, Val{false}})
148138
end
149139

150140
## OverrideInit
151141

152-
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
142+
function _initialize_dae!(integrator, prob::AbstractDEProblem,
153143
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
154-
initializeprob = prob.f.initializeprob
155-
156-
if SciMLBase.has_update_initializeprob!(prob.f)
157-
prob.f.update_initializeprob!(initializeprob, prob)
158-
end
144+
initializeprob = prob.f.initialization_data.initializeprob
159145

160146
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
161147
# Since then it's the case of not a DAE but has initializeprob
@@ -168,105 +154,30 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
168154
true
169155
end
170156

171-
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
172-
nlsol = solve(initializeprob, alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
157+
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
158+
159+
u0, p, success = SciMLBase.get_initial_values(prob, prob.f, integrator, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
160+
173161
if isinplace === Val{true}()
174-
integrator.u .= prob.f.initializeprobmap(nlsol)
162+
integrator.u .= u0
175163
elseif isinplace === Val{false}()
176-
integrator.u = prob.f.initializeprobmap(nlsol)
164+
integrator.u = u0
177165
else
178166
error("Unreachable reached. Report this error.")
179167
end
180-
if SciMLBase.has_initializeprobpmap(prob.f)
181-
integrator.p = prob.f.initializeprobpmap(prob, nlsol)
182-
sol = integrator.sol
183-
@reset sol.prob.p = integrator.p
184-
integrator.sol = sol
185-
end
168+
integrator.p = p
169+
sol = integrator.sol
170+
@reset sol.prob.p = integrator.p
171+
integrator.sol = sol
186172

187-
if nlsol.retcode != ReturnCode.Success
173+
if !success
188174
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
189175
ReturnCode.InitialFailure)
190176
end
191177
end
192178

193179
## CheckInit
194-
struct CheckInitFailureError <: Exception
195-
normresid::Any
196-
abstol::Any
197-
end
198-
199-
function Base.showerror(io::IO, e::CheckInitFailureError)
200-
print(io,
201-
"CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
202-
end
203-
204-
function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
205-
isinplace::Val{true})
206-
@unpack p, t, f = integrator
207-
M = integrator.f.mass_matrix
208-
tmp = first(get_tmp_cache(integrator))
209-
u0 = integrator.u
210-
211-
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
212-
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
213-
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
214-
update_coefficients!(M, u0, p, t)
215-
f(tmp, u0, p, t)
216-
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
217-
218-
normresid = integrator.opts.internalnorm(tmp, t)
219-
if normresid > integrator.opts.abstol
220-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
221-
end
222-
end
223-
224-
function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
225-
isinplace::Val{false})
226-
@unpack p, t, f = integrator
227-
u0 = integrator.u
228-
M = integrator.f.mass_matrix
229-
230-
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
231-
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
232-
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
233-
update_coefficients!(M, u0, p, t)
234-
du = f(u0, p, t)
235-
resid = _vec(du)[algebraic_eqs]
236-
237-
normresid = integrator.opts.internalnorm(resid, t)
238-
if normresid > integrator.opts.abstol
239-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
240-
end
241-
end
242-
243-
function _initialize_dae!(integrator, prob::DAEProblem,
244-
alg::CheckInit, isinplace::Val{true})
245-
@unpack p, t, f = integrator
246-
u0 = integrator.u
247-
resid = get_tmp_cache(integrator)[2]
248-
249-
f(resid, integrator.du, u0, p, t)
250-
normresid = integrator.opts.internalnorm(resid, t)
251-
if normresid > integrator.opts.abstol
252-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
253-
end
254-
end
255-
256-
function _initialize_dae!(integrator, prob::DAEProblem,
257-
alg::CheckInit, isinplace::Val{false})
258-
@unpack p, t, f = integrator
259-
u0 = integrator.u
260-
261-
nlequation_oop = u -> begin
262-
f((u - u0) / dt, u, p, t)
263-
end
264-
265-
nlequation = (u, _) -> nlequation_oop(u)
266-
267-
resid = f(integrator.du, u0, p, t)
268-
normresid = integrator.opts.internalnorm(resid, t)
269-
if normresid > integrator.opts.abstol
270-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
271-
end
180+
function _initialize_dae!(integrator, prob::AbstractDEProblem, alg::CheckInit,
181+
isinplace::Union{Val{true}, Val{false}})
182+
SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; abstol = integrator.opts.abstol)
272183
end

test/interface/checkinit_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ roberf_oop = ODEFunction{false}(rober, mass_matrix = M)
2424
prob_mm = ODEProblem(roberf, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
2525
prob_mm_oop = ODEProblem(roberf_oop, [1.0, 0.0, 0.2], (0.0, 1e5), (0.04, 3e7, 1e4))
2626

27-
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(
27+
@test_throws SciMLBase.CheckInitFailureError solve(
2828
prob_mm, Rodas5P(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())
29-
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(
29+
@test_throws SciMLBase.CheckInitFailureError solve(
3030
prob_mm_oop, Rodas5P(), reltol = 1e-8, abstol = 1e-8,
3131
initializealg = SciMLBase.CheckInit())
3232

@@ -49,7 +49,7 @@ tspan = (0.0, 100000.0)
4949
differential_vars = [true, true, false]
5050
prob = DAEProblem(f, du₀, u₀, tspan, differential_vars = differential_vars)
5151
prob_oop = DAEProblem(f_oop, du₀, u₀, tspan, differential_vars = differential_vars)
52-
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(
52+
@test_throws SciMLBase.CheckInitFailureError solve(
5353
prob, DFBDF(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())
54-
@test_throws OrdinaryDiffEqCore.CheckInitFailureError solve(
54+
@test_throws SciMLBase.CheckInitFailureError solve(
5555
prob_oop, DFBDF(), reltol = 1e-8, abstol = 1e-8, initializealg = SciMLBase.CheckInit())

0 commit comments

Comments
 (0)