Skip to content

Commit 64e2ee0

Browse files
feat: require providing tolerances in CheckInit and OverrideInit
1 parent 6e38d68 commit 64e2ee0

File tree

2 files changed

+71
-26
lines changed

2 files changed

+71
-26
lines changed

src/initialization.jl

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
6868
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
6969
end
7070

71+
struct OverrideInitNoTolerance <: Exception
72+
tolerance::Symbol
73+
end
74+
75+
function Base.showerror(io::IO, e::OverrideInitNoTolerance)
76+
print(io,
77+
"Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.")
78+
end
79+
7180
"""
7281
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
7382
it is in-place or simply calling the function if not.
@@ -98,11 +107,16 @@ _vec(v::AbstractVector) = v
98107
99108
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100109
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101-
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
110+
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
111+
112+
Keyword arguments:
113+
- `abstol`: The absolute value below which the norm of the residual of algebraic equations
114+
should lie. The norm function used is `integrator.opts.internalnorm` if present, and
115+
`LinearAlgebra.norm` if not.
102116
"""
103117
function get_initial_values(
104118
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
105-
isinplace::Union{Val{true}, Val{false}}; kwargs...)
119+
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
106120
u0 = state_values(integrator)
107121
p = parameter_values(integrator)
108122
t = current_time(integrator)
@@ -117,8 +131,8 @@ function get_initial_values(
117131

118132
normresid = isdefined(integrator.opts, :internalnorm) ?
119133
integrator.opts.internalnorm(tmp, t) : norm(tmp)
120-
if normresid > integrator.opts.abstol
121-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
134+
if normresid > abstol
135+
throw(CheckInitFailureError(normresid, abstol))
122136
end
123137
return u0, p, true
124138
end
@@ -139,16 +153,20 @@ end
139153

140154
function get_initial_values(
141155
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
142-
isinplace::Union{Val{true}, Val{false}}; kwargs...)
156+
isinplace::Union{Val{true}, Val{false}}; abstol = nothing, kwargs...)
143157
u0 = state_values(integrator)
144158
p = parameter_values(integrator)
145159
t = current_time(integrator)
146160

147161
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
148162
normresid = isdefined(integrator.opts, :internalnorm) ?
149163
integrator.opts.internalnorm(resid, t) : norm(resid)
150-
if normresid > integrator.opts.abstol
151-
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
164+
165+
if abstol === nothing
166+
abstol = cache_get_abstol(integrator)
167+
end
168+
if normresid > abstol
169+
throw(CheckInitFailureError(normresid, abstol))
152170
end
153171
return u0, p, true
154172
end
@@ -159,12 +177,19 @@ end
159177
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
160178
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
161179
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
162-
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
163-
argument, failing which this function will throw an error. The success value returned
164-
depends on the success of the nonlinear solve.
180+
181+
The success value returned depends on the success of the nonlinear solve.
182+
183+
Keyword arguments:
184+
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
185+
throw an error.
186+
- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
187+
provided to the `OverrideInit` constructor takes priority over this keyword argument.
188+
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
189+
an error will be thrown.
165190
"""
166191
function get_initial_values(prob, valp, f, alg::OverrideInit,
167-
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
192+
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
168193
u0 = state_values(valp)
169194
p = parameter_values(valp)
170195

@@ -185,10 +210,20 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
185210
end
186211

187212
if alg.abstol !== nothing
188-
nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol)
213+
_abstol = alg.abstol
214+
elseif abstol !== nothing
215+
_abstol = abstol
216+
else
217+
throw(OverrideInitNoTolerance(:abstol))
218+
end
219+
if alg.reltol !== nothing
220+
_reltol = alg.reltol
221+
elseif reltol !== nothing
222+
_reltol = reltol
189223
else
190-
nlsol = solve(initprob, nlsolve_alg)
224+
throw(OverrideInitNoTolerance(:reltol))
191225
end
226+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
192227

193228
u0 = initdata.initializeprobmap(nlsol)
194229
if initdata.initializeprobpmap !== nothing

test/initialization.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
1717
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
1818
integ = init(prob)
1919
u0, _, success = SciMLBase.get_initial_values(
20-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
20+
prob, integ, f, SciMLBase.CheckInit(),
21+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
2122
@test success
2223
@test u0 == prob.u0
2324

2425
integ.u[2] = 2.0
2526
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
26-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
27+
prob, integ, f, SciMLBase.CheckInit(),
28+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
2729
end
2830
end
2931

@@ -43,18 +45,21 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
4345
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
4446
integ = init(prob, DImplicitEuler())
4547
u0, _, success = SciMLBase.get_initial_values(
46-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
48+
prob, integ, f, SciMLBase.CheckInit(),
49+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
4750
@test success
4851
@test u0 == prob.u0
4952

5053
integ.u[2] = 2.0
5154
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
52-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
55+
prob, integ, f, SciMLBase.CheckInit(),
56+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
5357

5458
integ.u[2] = 1.0
5559
integ.du[1] = 2.0
5660
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
57-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
61+
prob, integ, f, SciMLBase.CheckInit(),
62+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
5863
end
5964
end
6065

@@ -86,13 +91,15 @@ using StochasticDiffEq, OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterfac
8691
prob = SDEProblem(f, [1.0, 1.0, -1.0], (0.0, 1.0))
8792
integ = init(prob, ImplicitEM())
8893
u0, _, success = SciMLBase.get_initial_values(
89-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
94+
prob, integ, f, SciMLBase.CheckInit(),
95+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
9096
@test success
9197
@test u0 == prob.u0
9298

9399
integ.u[2] = 2.0
94100
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
95-
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
101+
prob, integ, f, SciMLBase.CheckInit(),
102+
Val(SciMLBase.isinplace(f)); abstol = 1e-10)
96103
end
97104
end
98105
end
@@ -138,11 +145,13 @@ end
138145
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
139146
end
140147

148+
abstol = 1e-10
149+
reltol = 1e-10
141150
@testset "Solves" begin
142151
@testset "with explicit alg" begin
143152
u0, p, success = SciMLBase.get_initial_values(
144153
prob, integ, fn, SciMLBase.OverrideInit(),
145-
Val(false); nlsolve_alg = NewtonRaphson())
154+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
146155

147156
@test u0 [2.0, 2.0]
148157
@test p 1.0
@@ -152,7 +161,8 @@ end
152161
end
153162
@testset "with alg in `OverrideInit`" begin
154163
u0, p, success = SciMLBase.get_initial_values(
155-
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
164+
prob, integ, fn,
165+
SciMLBase.OverrideInit(; nlsolve = NewtonRaphson(), abstol, reltol),
156166
Val(false))
157167

158168
@test u0 [2.0, 2.0]
@@ -170,7 +180,7 @@ end
170180
_integ = init(_prob; initializealg = NoInit())
171181

172182
u0, p, success = SciMLBase.get_initial_values(
173-
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))
183+
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false); abstol, reltol)
174184

175185
@test u0 [1.0, 1.0]
176186
@test p 1.0
@@ -182,7 +192,7 @@ end
182192
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
183193
u0, p, success = SciMLBase.get_initial_values(
184194
prob, _integ, fn, SciMLBase.OverrideInit(),
185-
Val(false); nlsolve_alg = NewtonRaphson())
195+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
186196

187197
@test u0 [2.0, 2.0]
188198
@test p 1.0
@@ -199,7 +209,7 @@ end
199209

200210
u0, p, success = SciMLBase.get_initial_values(
201211
prob, integ, fn, SciMLBase.OverrideInit(),
202-
Val(false); nlsolve_alg = NewtonRaphson())
212+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
203213
@test u0 [1.0, 1.0]
204214
@test p 1.0
205215
@test success
@@ -213,7 +223,7 @@ end
213223

214224
u0, p, success = SciMLBase.get_initial_values(
215225
prob, integ, fn, SciMLBase.OverrideInit(),
216-
Val(false); nlsolve_alg = NewtonRaphson())
226+
Val(false); nlsolve_alg = NewtonRaphson(), abstol, reltol)
217227

218228
@test u0 [2.0, 2.0]
219229
@test p 0.0

0 commit comments

Comments
 (0)