Skip to content

Commit e0fade7

Browse files
feat: add fields to OverrideInit, better nlsolve_alg handling
1 parent 169d419 commit e0fade7

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

src/SciMLBase.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
2121
import FunctionWrappersWrappers
2222
import RuntimeGeneratedFunctions
2323
import EnumX
24-
import ADTypes: AbstractADType
24+
import ADTypes: ADTypes, AbstractADType
2525
import Accessors: @set, @reset
2626
using Expronicon.ADT: @match
2727

@@ -351,7 +351,15 @@ struct CheckInit <: DAEInitializationAlgorithm end
351351
"""
352352
$(TYPEDEF)
353353
"""
354-
struct OverrideInit <: DAEInitializationAlgorithm end
354+
struct OverrideInit{T, F} <: DAEInitializationAlgorithm
355+
abstol::T
356+
nlsolve::F
357+
end
358+
359+
function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
360+
OverrideInit(abstol, nlsolve)
361+
end
362+
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)
355363

356364
# PDE Discretizations
357365

src/initialization.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ argument, failing which this function will throw an error. The success value ret
160160
depends on the success of the nonlinear solve.
161161
"""
162162
function get_initial_values(prob, valp, f, alg::OverrideInit,
163-
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
163+
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
164164
u0 = state_values(valp)
165165
p = parameter_values(valp)
166166

@@ -171,15 +171,16 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
171171
initdata::OverrideInitData = f.initialization_data
172172
initprob = initdata.initializeprob
173173

174-
if nlsolve_alg === nothing
174+
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
175+
if nlsolve_alg === nothing && state_values(initprob) !== nothing
175176
throw(OverrideInitMissingAlgorithm())
176177
end
177178

178179
if initdata.update_initializeprob! !== nothing
179180
initdata.update_initializeprob!(initprob, valp)
180181
end
181182

182-
nlsol = solve(initprob, nlsolve_alg)
183+
nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol)
183184

184185
u0 = initdata.initializeprobmap(nlsol)
185186
if initdata.initializeprobpmap !== nothing

test/initialization.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,43 @@ end
101101
end
102102

103103
@testset "Solves" begin
104-
u0, p, success = SciMLBase.get_initial_values(
105-
prob, integ, fn, SciMLBase.OverrideInit(),
106-
Val(false); nlsolve_alg = NewtonRaphson())
104+
@testset "with explicit alg" begin
105+
u0, p, success = SciMLBase.get_initial_values(
106+
prob, integ, fn, SciMLBase.OverrideInit(),
107+
Val(false); nlsolve_alg = NewtonRaphson())
107108

108-
@test u0 [2.0, 2.0]
109-
@test p 1.0
110-
@test success
109+
@test u0 [2.0, 2.0]
110+
@test p 1.0
111+
@test success
111112

112-
initprob.p[1] = 1.0
113+
initprob.p[1] = 1.0
114+
end
115+
@testset "with alg in `OverrideInit`" begin
116+
u0, p, success = SciMLBase.get_initial_values(
117+
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
118+
Val(false))
119+
120+
@test u0 [2.0, 2.0]
121+
@test p 1.0
122+
@test success
123+
124+
initprob.p[1] = 1.0
125+
end
126+
@testset "with trivial problem and no alg" begin
127+
iprob = NonlinearProblem((u, p) -> 0.0, nothing, 1.0)
128+
iprobmap = (_) -> [1.0, 1.0]
129+
initdata = SciMLBase.OverrideInitData(iprob, nothing, iprobmap, nothing)
130+
_fn = ODEFunction(rhs2; initialization_data = initdata)
131+
_prob = ODEProblem(_fn, [2.0, 0.0], (0.0, 1.0), 1.0)
132+
_integ = init(_prob; initializealg = NoInit())
133+
134+
u0, p, success = SciMLBase.get_initial_values(
135+
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))
136+
137+
@test u0 [1.0, 1.0]
138+
@test p 1.0
139+
@test success
140+
end
113141
end
114142

115143
@testset "Solves with non-integrator value provider" begin

0 commit comments

Comments
 (0)