Skip to content

Commit 3cf6e49

Browse files
feat: require providing tolerances in CheckInit and OverrideInit
1 parent 8c87d4a commit 3cf6e49

File tree

2 files changed

+57
-29
lines changed

2 files changed

+57
-29
lines changed

src/initialization.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,15 @@ _vec(v::AbstractVector) = v
9898
9999
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100100
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101-
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
101+
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
102+
103+
Keyword arguments:
104+
- `abstol`: Defaults to `cache_get_abstol(integrator)`, requiring that the integrator implement
105+
`cache_stores_tolerances`.
102106
"""
103107
function get_initial_values(
104108
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
105-
isinplace::Union{Val{true}, Val{false}}; kwargs...)
109+
isinplace::Union{Val{true}, Val{false}}; abstol = nothing, kwargs...)
106110
u0 = state_values(integrator)
107111
p = parameter_values(integrator)
108112
t = current_time(integrator)
@@ -117,8 +121,11 @@ function get_initial_values(
117121

118122
normresid = isdefined(integrator.opts, :internalnorm) ?
119123
integrator.opts.internalnorm(tmp, t) : norm(tmp)
120-
if normresid > integrator.opts.abstol
121-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
124+
if abstol === nothing
125+
abstol = cache_get_abstol(integrator)
126+
end
127+
if normresid > abstol
128+
throw(CheckInitFailureError(normresid, abstol))
122129
end
123130
return u0, p, true
124131
end
@@ -139,16 +146,20 @@ end
139146

140147
function get_initial_values(
141148
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
142-
isinplace::Union{Val{true}, Val{false}}; kwargs...)
149+
isinplace::Union{Val{true}, Val{false}}; abstol = nothing, kwargs...)
143150
u0 = state_values(integrator)
144151
p = parameter_values(integrator)
145152
t = current_time(integrator)
146153

147154
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
148155
normresid = isdefined(integrator.opts, :internalnorm) ?
149156
integrator.opts.internalnorm(resid, t) : norm(resid)
150-
if normresid > integrator.opts.abstol
151-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
157+
158+
if abstol === nothing
159+
abstol = cache_get_abstol(integrator)
160+
end
161+
if normresid > abstol
162+
throw(CheckInitFailureError(normresid, abstol))
152163
end
153164
return u0, p, true
154165
end
@@ -159,12 +170,21 @@ end
159170
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
160171
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
161172
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
162-
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
163-
argument, failing which this function will throw an error. The success value returned
164-
depends on the success of the nonlinear solve.
173+
174+
The success value returned depends on the success of the nonlinear solve.
175+
176+
Keyword arguments:
177+
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
178+
throw an error.
179+
- `abstol`: The `abstol` to use for the nonlinear solve. Falls back to the `abstol` provided
180+
to `OverrideInit`, and then to `cache_get_abstol`. If none of these contain a tolerance,
181+
throws an error.
182+
- `reltol`: The `reltol` to use for the nonlinear solve. Falls back to the `reltol` provided
183+
to `OverrideInit`, and then to `cache_get_reltol`. If none of these contain a tolerance,
184+
throws an error.
165185
"""
166186
function get_initial_values(prob, valp, f, alg::OverrideInit,
167-
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
187+
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
168188
u0 = state_values(valp)
169189
p = parameter_values(valp)
170190

@@ -184,11 +204,9 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
184204
initdata.update_initializeprob!(initprob, valp)
185205
end
186206

187-
if alg.abstol !== nothing
188-
nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol)
189-
else
190-
nlsol = solve(initprob, nlsolve_alg)
191-
end
207+
_abstol = @something(alg.abstol, abstol, cache_get_abstol(valp))
208+
_reltol = @something(alg.reltol, reltol, cache_get_reltol(valp))
209+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
192210

193211
u0 = initdata.initializeprobmap(nlsol)
194212
if initdata.initializeprobpmap !== nothing

test/initialization.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
1717
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
1818
integ = init(prob)
1919
u0, _, success = SciMLBase.get_initial_values(
20-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
20+
prob, integ, f, SciMLBase.CheckInit(),
21+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
2122
@test success
2223
@test u0 == prob.u0
2324

2425
integ.u[2] = 2.0
2526
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
26-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
27+
prob, integ, f, SciMLBase.CheckInit(),
28+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
2729
end
2830
end
2931

@@ -43,18 +45,21 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
4345
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
4446
integ = init(prob, DImplicitEuler())
4547
u0, _, success = SciMLBase.get_initial_values(
46-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
48+
prob, integ, f, SciMLBase.CheckInit(),
49+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
4750
@test success
4851
@test u0 == prob.u0
4952

5053
integ.u[2] = 2.0
5154
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
52-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
55+
prob, integ, f, SciMLBase.CheckInit(),
56+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
5357

5458
integ.u[2] = 1.0
5559
integ.du[1] = 2.0
5660
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
57-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
61+
prob, integ, f, SciMLBase.CheckInit(),
62+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
5863
end
5964
end
6065

@@ -86,13 +91,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
8691
prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0))
8792
integ = init(prob, ImplicitEM())
8893
u0, _, success = SciMLBase.get_initial_values(
89-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
94+
prob, integ, f, SciMLBase.CheckInit(),
95+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
9096
@test success
9197
@test u0 == prob.u0
9298

9399
integ.u[2] = 2.0
94100
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
95-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
101+
prob, integ, f, SciMLBase.CheckInit(),
102+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
96103
end
97104
end
98105
end
@@ -138,11 +145,13 @@ end
138145
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
139146
end
140147

148+
abstol = 1e-10
149+
reltol = 1e-10
141150
@testset "Solves" begin
142151
@testset "with explicit alg" begin
143152
u0, p, success = SciMLBase.get_initial_values(
144153
prob, integ, fn, SciMLBase.OverrideInit(),
145-
Val(false); nlsolve_alg = NewtonRaphson())
154+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
146155

147156
@test u0 [2.0, 2.0]
148157
@test p 1.0
@@ -152,7 +161,8 @@ end
152161
end
153162
@testset "with alg in `OverrideInit`" begin
154163
u0, p, success = SciMLBase.get_initial_values(
155-
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
164+
prob, integ, fn,
165+
SciMLBase.OverrideInit(; nlsolve = NewtonRaphson(), abstol, reltol),
156166
Val(false))
157167

158168
@test u0 [2.0, 2.0]
@@ -170,7 +180,7 @@ end
170180
_integ = init(_prob; initializealg = NoInit())
171181

172182
u0, p, success = SciMLBase.get_initial_values(
173-
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))
183+
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false); abstol, reltol)
174184

175185
@test u0 [1.0, 1.0]
176186
@test p 1.0
@@ -182,7 +192,7 @@ end
182192
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
183193
u0, p, success = SciMLBase.get_initial_values(
184194
prob, _integ, fn, SciMLBase.OverrideInit(),
185-
Val(false); nlsolve_alg = NewtonRaphson())
195+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
186196

187197
@test u0 [2.0, 2.0]
188198
@test p 1.0
@@ -199,7 +209,7 @@ end
199209

200210
u0, p, success = SciMLBase.get_initial_values(
201211
prob, integ, fn, SciMLBase.OverrideInit(),
202-
Val(false); nlsolve_alg = NewtonRaphson())
212+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
203213
@test u0 [1.0, 1.0]
204214
@test p 1.0
205215
@test success
@@ -213,7 +223,7 @@ end
213223

214224
u0, p, success = SciMLBase.get_initial_values(
215225
prob, integ, fn, SciMLBase.OverrideInit(),
216-
Val(false); nlsolve_alg = NewtonRaphson())
226+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
217227

218228
@test u0 [2.0, 2.0]
219229
@test p 0.0

0 commit comments

Comments
 (0)