Skip to content

Commit a6a4a1f

Browse files
feat: do not require nlsolve_alg for trivial OverrideInit
1 parent a5ee8e9 commit a6a4a1f

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

src/initialization.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ Keyword arguments:
171171
provided to the `OverrideInit` constructor takes priority over this keyword argument.
172172
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
173173
an error will be thrown.
174+
175+
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
176+
not required.
174177
"""
175178
function get_initial_values(prob, valp, f, alg::OverrideInit,
176179
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
@@ -193,26 +196,32 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
193196
initdata.update_initializeprob!(initprob, valp)
194197
end
195198

196-
if alg.abstol !== nothing
197-
_abstol = alg.abstol
198-
elseif abstol !== nothing
199-
_abstol = abstol
200-
else
201-
throw(OverrideInitNoTolerance(:abstol))
202-
end
203-
if alg.reltol !== nothing
204-
_reltol = alg.reltol
205-
elseif reltol !== nothing
206-
_reltol = reltol
199+
if state_values(initprob) === nothing
200+
nlsol = initprob
201+
success = true
207202
else
208-
throw(OverrideInitNoTolerance(:reltol))
203+
if alg.abstol !== nothing
204+
_abstol = alg.abstol
205+
elseif abstol !== nothing
206+
_abstol = abstol
207+
else
208+
throw(OverrideInitNoTolerance(:abstol))
209+
end
210+
if alg.reltol !== nothing
211+
_reltol = alg.reltol
212+
elseif reltol !== nothing
213+
_reltol = reltol
214+
else
215+
throw(OverrideInitNoTolerance(:reltol))
216+
end
217+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
218+
success = SciMLBase.successful_retcode(nlsol)
209219
end
210-
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
211220

212221
u0 = initdata.initializeprobmap(nlsol)
213222
if initdata.initializeprobpmap !== nothing
214223
p = initdata.initializeprobpmap(valp, nlsol)
215224
end
216225

217-
return u0, p, SciMLBase.successful_retcode(nlsol)
226+
return u0, p, success
218227
end

test/initialization.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,30 @@ end
229229
@test p 0.0
230230
@test success
231231
end
232+
233+
@testset "Trivial initialization" begin
234+
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
235+
update_initializeprob! = function (iprob, integ)
236+
iprob.p[1] = integ.u[1]
237+
end
238+
initprobmap = function (nlsol)
239+
u1 = parameter_values(nlsol)[1]
240+
return [u1, u1]
241+
end
242+
initprobpmap = function (_, nlsol)
243+
return 0.0
244+
end
245+
initialization_data = SciMLBase.OverrideInitData(
246+
initprob, update_initializeprob!, initprobmap, initprobpmap)
247+
fn = ODEFunction(rhs2; initialization_data)
248+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
249+
integ = init(prob; initializealg = NoInit())
250+
251+
u0, p, success = SciMLBase.get_initial_values(
252+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
253+
)
254+
@test u0 [2.0, 2.0]
255+
@test p 0.0
256+
@test success
257+
end
232258
end

0 commit comments

Comments
 (0)