Skip to content

Commit a4fd6d8

Browse files
feat: do not require nlsolve_alg for trivial OverrideInit
1 parent 86aa145 commit a4fd6d8

File tree

2 files changed

+53
-19
lines changed

2 files changed

+53
-19
lines changed

src/initialization.jl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ Keyword arguments:
172172
provided to the `OverrideInit` constructor takes priority over this keyword argument.
173173
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
174174
an error will be thrown.
175+
176+
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
177+
not required.
175178
"""
176179
function get_initial_values(prob, valp, f, alg::OverrideInit,
177180
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
@@ -185,35 +188,40 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
185188
initdata::OverrideInitData = f.initialization_data
186189
initprob = initdata.initializeprob
187190

188-
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
189-
if nlsolve_alg === nothing && state_values(initprob) !== nothing
190-
throw(OverrideInitMissingAlgorithm())
191-
end
192-
193191
if initdata.update_initializeprob! !== nothing
194192
initdata.update_initializeprob!(initprob, valp)
195193
end
196194

197-
if alg.abstol !== nothing
198-
_abstol = alg.abstol
199-
elseif abstol !== nothing
200-
_abstol = abstol
201-
else
202-
throw(OverrideInitNoTolerance(:abstol))
203-
end
204-
if alg.reltol !== nothing
205-
_reltol = alg.reltol
206-
elseif reltol !== nothing
207-
_reltol = reltol
195+
if state_values(initprob) === nothing
196+
nlsol = initprob
197+
success = true
208198
else
209-
throw(OverrideInitNoTolerance(:reltol))
199+
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
200+
if nlsolve_alg === nothing && state_values(initprob) !== nothing
201+
throw(OverrideInitMissingAlgorithm())
202+
end
203+
if alg.abstol !== nothing
204+
_abstol = alg.abstol
205+
elseif abstol !== nothing
206+
_abstol = abstol
207+
else
208+
throw(OverrideInitNoTolerance(:abstol))
209+
end
210+
if alg.reltol !== nothing
211+
_reltol = alg.reltol
212+
elseif reltol !== nothing
213+
_reltol = reltol
214+
else
215+
throw(OverrideInitNoTolerance(:reltol))
216+
end
217+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
218+
success = SciMLBase.successful_retcode(nlsol)
210219
end
211-
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
212220

213221
u0 = initdata.initializeprobmap(nlsol)
214222
if initdata.initializeprobpmap !== nothing
215223
p = initdata.initializeprobpmap(valp, nlsol)
216224
end
217225

218-
return u0, p, SciMLBase.successful_retcode(nlsol)
226+
return u0, p, success
219227
end

test/initialization.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,30 @@ end
244244
@test p 0.0
245245
@test success
246246
end
247+
248+
@testset "Trivial initialization" begin
249+
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
250+
update_initializeprob! = function (iprob, integ)
251+
iprob.p[1] = integ.u[1]
252+
end
253+
initprobmap = function (nlsol)
254+
u1 = parameter_values(nlsol)[1]
255+
return [u1, u1]
256+
end
257+
initprobpmap = function (_, nlsol)
258+
return 0.0
259+
end
260+
initialization_data = SciMLBase.OverrideInitData(
261+
initprob, update_initializeprob!, initprobmap, initprobpmap)
262+
fn = ODEFunction(rhs2; initialization_data)
263+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
264+
integ = init(prob; initializealg = NoInit())
265+
266+
u0, p, success = SciMLBase.get_initial_values(
267+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
268+
)
269+
@test u0 [2.0, 2.0]
270+
@test p 0.0
271+
@test success
272+
end
247273
end

0 commit comments

Comments
 (0)