Skip to content

Commit ded6aa8

Browse files
feat: do not require nlsolve_alg for trivial OverrideInit
1 parent 598c7cd commit ded6aa8

File tree

2 files changed

+53
-19
lines changed

2 files changed

+53
-19
lines changed

src/initialization.jl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ Keyword arguments:
188188
provided to the `OverrideInit` constructor takes priority over this keyword argument.
189189
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
190190
an error will be thrown.
191+
192+
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
193+
not required.
191194
"""
192195
function get_initial_values(prob, valp, f, alg::OverrideInit,
193196
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
@@ -201,35 +204,40 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
201204
initdata::OverrideInitData = f.initialization_data
202205
initprob = initdata.initializeprob
203206

204-
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
205-
if nlsolve_alg === nothing && state_values(initprob) !== nothing
206-
throw(OverrideInitMissingAlgorithm())
207-
end
208-
209207
if initdata.update_initializeprob! !== nothing
210208
initdata.update_initializeprob!(initprob, valp)
211209
end
212210

213-
if alg.abstol !== nothing
214-
_abstol = alg.abstol
215-
elseif abstol !== nothing
216-
_abstol = abstol
217-
else
218-
throw(OverrideInitNoTolerance(:abstol))
219-
end
220-
if alg.reltol !== nothing
221-
_reltol = alg.reltol
222-
elseif reltol !== nothing
223-
_reltol = reltol
211+
if state_values(initprob) === nothing
212+
nlsol = initprob
213+
success = true
224214
else
225-
throw(OverrideInitNoTolerance(:reltol))
215+
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
216+
if nlsolve_alg === nothing && state_values(initprob) !== nothing
217+
throw(OverrideInitMissingAlgorithm())
218+
end
219+
if alg.abstol !== nothing
220+
_abstol = alg.abstol
221+
elseif abstol !== nothing
222+
_abstol = abstol
223+
else
224+
throw(OverrideInitNoTolerance(:abstol))
225+
end
226+
if alg.reltol !== nothing
227+
_reltol = alg.reltol
228+
elseif reltol !== nothing
229+
_reltol = reltol
230+
else
231+
throw(OverrideInitNoTolerance(:reltol))
232+
end
233+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
234+
success = SciMLBase.successful_retcode(nlsol)
226235
end
227-
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
228236

229237
u0 = initdata.initializeprobmap(nlsol)
230238
if initdata.initializeprobpmap !== nothing
231239
p = initdata.initializeprobpmap(valp, nlsol)
232240
end
233241

234-
return u0, p, SciMLBase.successful_retcode(nlsol)
242+
return u0, p, success
235243
end

test/initialization.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,30 @@ end
244244
@test p 0.0
245245
@test success
246246
end
247+
248+
@testset "Trivial initialization" begin
249+
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
250+
update_initializeprob! = function (iprob, integ)
251+
iprob.p[1] = integ.u[1]
252+
end
253+
initprobmap = function (nlsol)
254+
u1 = parameter_values(nlsol)[1]
255+
return [u1, u1]
256+
end
257+
initprobpmap = function (_, nlsol)
258+
return 0.0
259+
end
260+
initialization_data = SciMLBase.OverrideInitData(
261+
initprob, update_initializeprob!, initprobmap, initprobpmap)
262+
fn = ODEFunction(rhs2; initialization_data)
263+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
264+
integ = init(prob; initializealg = NoInit())
265+
266+
u0, p, success = SciMLBase.get_initial_values(
267+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
268+
)
269+
@test u0 [2.0, 2.0]
270+
@test p 0.0
271+
@test success
272+
end
247273
end

0 commit comments

Comments
 (0)