@@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
6868 " OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`." )
6969end
7070
71+ struct OverrideInitNoTolerance <: Exception
72+ tolerance:: Symbol
73+ end
74+
75+ function Base. showerror (io:: IO , e:: OverrideInitNoTolerance )
76+ print (io,
77+ " Tolerances were not provided to `OverrideInit`. `$(e. tolerance) ` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor." )
78+ end
79+
7180"""
72- Utility function to evaluate the RHS of the ODE , using the integrator's `tmp_cache` if
81+ Utility function to evaluate the RHS, using the integrator's `tmp_cache` if
7382it is in-place or simply calling the function if not.
7483"""
75- function _evaluate_f_ode (integrator, f, isinplace:: Val{true} , args... )
84+ function _evaluate_f (integrator, f, isinplace:: Val{true} , args... )
7685 tmp = first (get_tmp_cache (integrator))
7786 f (tmp, args... )
7887 return tmp
7988end
8089
81- function _evaluate_f_ode (integrator, f, isinplace:: Val{false} , args... )
90+ function _evaluate_f (integrator, f, isinplace:: Val{false} , args... )
8291 return f (args... )
8392end
8493
@@ -98,53 +107,49 @@ _vec(v::AbstractVector) = v
98107
99108Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100109the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101- `ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
110+ `AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
111+
112+ Keyword arguments:
113+ - `abstol`: The absolute value below which the norm of the residual of algebraic equations
114+ should lie. The norm function used is `integrator.opts.internalnorm` if present, and
115+ `LinearAlgebra.norm` if not.
102116"""
103- function get_initial_values (prob:: AbstractODEProblem , integrator, f, alg:: CheckInit ,
104- isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
117+ function get_initial_values (
118+ prob:: AbstractDEProblem , integrator:: DEIntegrator , f, alg:: CheckInit ,
119+ isinplace:: Union{Val{true}, Val{false}} ; abstol, kwargs... )
105120 u0 = state_values (integrator)
106121 p = parameter_values (integrator)
107122 t = current_time (integrator)
108123 M = f. mass_matrix
109124
110125 algebraic_vars = [all (iszero, x) for x in eachcol (M)]
111126 algebraic_eqs = [all (iszero, x) for x in eachrow (M)]
112- (iszero (algebraic_vars) || iszero (algebraic_eqs)) && return
127+ (iszero (algebraic_vars) || iszero (algebraic_eqs)) && return u0, p, true
113128 update_coefficients! (M, u0, p, t)
114- tmp = _evaluate_f_ode (integrator, f, isinplace, u0, p, t)
129+ tmp = _evaluate_f (integrator, f, isinplace, u0, p, t)
115130 tmp .= ArrayInterface. restructure (tmp, algebraic_eqs .* _vec (tmp))
116131
117- normresid = integrator. opts. internalnorm (tmp, t)
118- if normresid > integrator. opts. abstol
119- throw (CheckInitFailureError (normresid, integrator. opts. abstol))
132+ normresid = isdefined (integrator. opts, :internalnorm ) ?
133+ integrator. opts. internalnorm (tmp, t) : norm (tmp)
134+ if normresid > abstol
135+ throw (CheckInitFailureError (normresid, abstol))
120136 end
121137 return u0, p, true
122138end
123139
124- """
125- Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126- it is in-place or simply calling the function if not.
127- """
128- function _evaluate_f_dae (integrator, f, isinplace:: Val{true} , args... )
129- tmp = get_tmp_cache (integrator)[2 ]
130- f (tmp, args... )
131- return tmp
132- end
133-
134- function _evaluate_f_dae (integrator, f, isinplace:: Val{false} , args... )
135- return f (args... )
136- end
137-
138- function get_initial_values (prob:: AbstractDAEProblem , integrator, f, alg:: CheckInit ,
139- isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
140+ function get_initial_values (
141+ prob:: AbstractDAEProblem , integrator:: DEIntegrator , f, alg:: CheckInit ,
142+ isinplace:: Union{Val{true}, Val{false}} ; abstol, kwargs... )
140143 u0 = state_values (integrator)
141144 p = parameter_values (integrator)
142145 t = current_time (integrator)
143146
144- resid = _evaluate_f_dae (integrator, f, isinplace, integrator. du, u0, p, t)
145- normresid = integrator. opts. internalnorm (resid, t)
146- if normresid > integrator. opts. abstol
147- throw (CheckInitFailureError (normresid, integrator. opts. abstol))
147+ resid = _evaluate_f (integrator, f, isinplace, integrator. du, u0, p, t)
148+ normresid = isdefined (integrator. opts, :internalnorm ) ?
149+ integrator. opts. internalnorm (resid, t) : norm (resid)
150+
151+ if normresid > abstol
152+ throw (CheckInitFailureError (normresid, abstol))
148153 end
149154 return u0, p, true
150155end
@@ -155,12 +160,19 @@ end
155160Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156161`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157162If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158- The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159- argument, failing which this function will throw an error. The success value returned
160- depends on the success of the nonlinear solve.
163+
164+ The success value returned depends on the success of the nonlinear solve.
165+
166+ Keyword arguments:
167+ - `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
168+ throw an error.
169+ - `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
170+ provided to the `OverrideInit` constructor takes priority over this keyword argument.
171+ If the former is `nothing`, this keyword argument will be used. If it is also not provided,
172+ an error will be thrown.
161173"""
162174function get_initial_values (prob, valp, f, alg:: OverrideInit ,
163- isinplace :: Union{Val{true}, Val{false}} ; nlsolve_alg = nothing , kwargs... )
175+ iip :: Union{Val{true}, Val{false}} ; nlsolve_alg = nothing , abstol = nothing , reltol = nothing , kwargs... )
164176 u0 = state_values (valp)
165177 p = parameter_values (valp)
166178
@@ -171,15 +183,30 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
171183 initdata:: OverrideInitData = f. initialization_data
172184 initprob = initdata. initializeprob
173185
174- if nlsolve_alg === nothing
186+ nlsolve_alg = something (nlsolve_alg, alg. nlsolve, Some (nothing ))
187+ if nlsolve_alg === nothing && state_values (initprob) != = nothing
175188 throw (OverrideInitMissingAlgorithm ())
176189 end
177190
178191 if initdata. update_initializeprob! != = nothing
179192 initdata. update_initializeprob! (initprob, valp)
180193 end
181194
182- nlsol = solve (initprob, nlsolve_alg)
195+ if alg. abstol != = nothing
196+ _abstol = alg. abstol
197+ elseif abstol != = nothing
198+ _abstol = abstol
199+ else
200+ throw (OverrideInitNoTolerance (:abstol ))
201+ end
202+ if alg. reltol != = nothing
203+ _reltol = alg. reltol
204+ elseif reltol != = nothing
205+ _reltol = reltol
206+ else
207+ throw (OverrideInitNoTolerance (:reltol ))
208+ end
209+ nlsol = solve (initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
183210
184211 u0 = initdata. initializeprobmap (nlsol)
185212 if initdata. initializeprobpmap != = nothing
0 commit comments