Skip to content

Commit 924e92b

Browse files
feat: add override_init_get_nlsolve
1 parent c61b13d commit 924e92b

File tree

3 files changed

+95
-14
lines changed

3 files changed

+95
-14
lines changed

src/SciMLBase.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
2121
import FunctionWrappersWrappers
2222
import RuntimeGeneratedFunctions
2323
import EnumX
24-
import ADTypes: AbstractADType
24+
import ADTypes: ADTypes, AbstractADType
2525
import Accessors: @set, @reset
2626
using Expronicon.ADT: @match
2727

@@ -351,7 +351,15 @@ struct CheckInit <: DAEInitializationAlgorithm end
351351
"""
352352
$(TYPEDEF)
353353
"""
354-
struct OverrideInit <: DAEInitializationAlgorithm end
354+
struct OverrideInit{T, F} <: DAEInitializationAlgorithm
355+
abstol::T
356+
nlsolve::F
357+
end
358+
359+
function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
360+
OverrideInit(abstol, nlsolve)
361+
end
362+
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)
355363

356364
# PDE Discretizations
357365

src/initialization.jl

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,51 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
6868
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
6969
end
7070

71+
struct NoNonlinearSolverError <: Exception
72+
end
73+
74+
function Base.showerror(io::IO, err::NoNonlinearSolverError)
75+
println(io, """
76+
This problem requires initialization and thus a nonlinear solve, but no nonlinear \
77+
solve has been loaded. If you are using OrdinaryDiffEq, import the \
78+
`OrdinaryDiffEqNonlinearSolve` package or pass a custom `nlsolve` into the \
79+
`initializealg`. If you are not using `OrdinaryDiffEq`, please open an issue in \
80+
the appropriate library with an MWE.
81+
""")
82+
end
83+
84+
"""
85+
$(TYPEDSIGNATURES)
86+
87+
Given a user-provided nonlinear solve algorithm `alg`, `iip::Union{Val{true}, Val{false}}`
88+
indicating whether the initialization problem is in-place or not, the initial state
89+
vector of the initialization problem, the initialization problem (either a
90+
`NonlinearProblem` or `NonlinearLeastSquaresProblem`) and a boolean `autodiff`
91+
indicating whether to use `AutoForwardDiff` or `AutoFiniteDiff`, return a nonlinear
92+
solve algorithm to use for solving the initialization. If `alg` is not nothing, it will
93+
be returned as-is. If the initialization problem is trivial (`u === nothing`) the trivial
94+
`nothing` algorithm will be used. Otherwise, requires `NonlinearSolve.jl` to
95+
automatically find an appropriate solver.
96+
"""
97+
override_init_get_nlsolve(alg, iip, u, prob, autodiff = false) = alg
98+
99+
for iip in (Val{true}, Val{false}), prob in (NonlinearProblem, NonlinearLeastSquaresProblem)
100+
@eval function override_init_get_nlsolve(
101+
::Nothing, ::$(iip), u::Nothing, ::$(prob), autodiff = false)
102+
nothing
103+
end
104+
end
105+
106+
function override_init_get_nlsolve(
107+
::Nothing, isinplace, u, initializeprob::NonlinearProblem, autodiff = false)
108+
throw(NoNonlinearSolverError())
109+
end
110+
111+
function override_init_get_nlsolve(
112+
::Nothing, isinplace, u, initializeprob::NonlinearLeastSquaresProblem, autodiff = false)
113+
throw(NoNonlinearSolverError())
114+
end
115+
71116
"""
72117
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
73118
it is in-place or simply calling the function if not.
@@ -160,7 +205,7 @@ argument, failing which this function will throw an error. The success value ret
160205
depends on the success of the nonlinear solve.
161206
"""
162207
function get_initial_values(prob, valp, f, alg::OverrideInit,
163-
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
208+
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, autodiff = false, kwargs...)
164209
u0 = state_values(valp)
165210
p = parameter_values(valp)
166211

@@ -171,9 +216,9 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
171216
initdata::OverrideInitData = f.initialization_data
172217
initprob = initdata.initializeprob
173218

174-
if nlsolve_alg === nothing
175-
throw(OverrideInitMissingAlgorithm())
176-
end
219+
nlsolve_alg = override_init_get_nlsolve(
220+
something(nlsolve_alg, alg.nlsolve, Some(nothing)),
221+
Val{isinplace(initprob)}(), state_values(initprob), initprob, autodiff)
177222

178223
if initdata.update_initializeprob! !== nothing
179224
initdata.update_initializeprob!(initprob, valp)

test/initialization.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,48 @@ end
9696
integ = init(prob; initializealg = NoInit())
9797

9898
@testset "Errors without `nlsolve_alg`" begin
99-
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values(
99+
@test_throws SciMLBase.NoNonlinearSolverError SciMLBase.get_initial_values(
100100
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
101101
end
102102

103103
@testset "Solves" begin
104-
u0, p, success = SciMLBase.get_initial_values(
105-
prob, integ, fn, SciMLBase.OverrideInit(),
106-
Val(false); nlsolve_alg = NewtonRaphson())
104+
@testset "with explicit alg" begin
105+
u0, p, success = SciMLBase.get_initial_values(
106+
prob, integ, fn, SciMLBase.OverrideInit(),
107+
Val(false); nlsolve_alg = NewtonRaphson())
107108

108-
@test u0 [2.0, 2.0]
109-
@test p 1.0
110-
@test success
109+
@test u0 [2.0, 2.0]
110+
@test p 1.0
111+
@test success
111112

112-
initprob.p[1] = 1.0
113+
initprob.p[1] = 1.0
114+
end
115+
@testset "with alg in `OverrideInit`" begin
116+
u0, p, success = SciMLBase.get_initial_values(
117+
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
118+
Val(false))
119+
120+
@test u0 [2.0, 2.0]
121+
@test p 1.0
122+
@test success
123+
124+
initprob.p[1] = 1.0
125+
end
126+
@testset "with trivial problem and no alg" begin
127+
iprob = NonlinearProblem((u, p) -> 0.0, nothing, 1.0)
128+
iprobmap = (_) -> [1.0, 1.0]
129+
initdata = SciMLBase.OverrideInitData(iprob, nothing, iprobmap, nothing)
130+
_fn = ODEFunction(rhs2; initialization_data = initdata)
131+
_prob = ODEProblem(_fn, [2.0, 0.0], (0.0, 1.0), 1.0)
132+
_integ = init(_prob; initializealg = NoInit())
133+
134+
u0, p, success = SciMLBase.get_initial_values(
135+
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))
136+
137+
@test u0 [1.0, 1.0]
138+
@test p 1.0
139+
@test success
140+
end
113141
end
114142

115143
@testset "Solves with non-integrator value provider" begin

0 commit comments

Comments
 (0)