Skip to content

Commit 03b2946

Browse files
Merge pull request #953 from AayushSabharwal/as/override-init-kwargs
feat: forward kwargs to `solve` in `OverrideInit`
2 parents f5b6ba7 + ab41d7a commit 03b2946

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

src/initialization.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,10 @@ Keyword arguments:
216216
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
217217
an error will be thrown.
218218
219+
All additional keyword arguments are forwarded to `solve`.
220+
219221
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
220-
not required.
222+
not required. `solve` is also not called.
221223
"""
222224
function get_initial_values(prob, valp, f, alg::OverrideInit,
223225
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
@@ -257,7 +259,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
257259
else
258260
throw(OverrideInitNoTolerance(:reltol))
259261
end
260-
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
262+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...)
261263
success = SciMLBase.successful_retcode(nlsol)
262264
end
263265

@@ -304,7 +306,8 @@ function initialization_status(prob::AbstractSciMLProblem)
304306
iprob = prob.f.initialization_data.initializeprob
305307
isnothing(prob) && return nothing
306308

307-
nunknowns = iprob.u0 === nothing ? 0 : length(iprob.u0)
309+
iu0 = state_values(iprob)
310+
nunknowns = iu0 === nothing ? 0 : length(iu0)
308311
neqs = if __has_resid_prototype(iprob.f) && iprob.f.resid_prototype !== nothing
309312
length(iprob.f.resid_prototype)
310313
else

test/initialization.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ end
203203
@test p 1.0
204204
@test success
205205
end
206+
@testset "with kwargs provided to `get_initial_values`" begin
207+
u0, p, success = SciMLBase.get_initial_values(
208+
prob, integ, fn, SciMLBase.OverrideInit(),
209+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol, u0 = [-1.0, 1.0])
210+
@test u0 [2.0, -2.0]
211+
@test p 1.0
212+
@test success
213+
end
206214
end
207215

208216
@testset "Solves with non-integrator value provider" begin
@@ -262,6 +270,15 @@ end
262270
@test success
263271
end
264272

273+
@testset "Initialization status for `SCCNonlinearProblem`" begin
274+
initprob = SCCNonlinearProblem([initprob], [Returns(nothing)])
275+
initialization_data = SciMLBase.OverrideInitData(
276+
initprob, nothing, nothing, nothing)
277+
fn = ODEFunction(rhs2; initialization_data)
278+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
279+
@test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED
280+
end
281+
265282
@testset "Trivial initialization" begin
266283
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
267284
update_initializeprob! = function (iprob, integ)

0 commit comments

Comments
 (0)