@@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
99 """
1010 initializeprob:: IProb
1111 """
12- A function which takes `(initializeprob, prob )` and updates
12+ A function which takes `(initializeprob, value_provider )` and updates
1313 the parameters of the former with their values in the latter.
14+ If absent (`nothing`) this will not be called, and the parameters
15+ in `initializeprob` will be used without modification. `value_provider`
16+ refers to a value provider as defined by SymbolicIndexingInterface.jl.
17+ Usually this will refer to a problem or integrator.
1418 """
1519 update_initializeprob!:: UIProb
1620 """
@@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
2024 initializeprobmap:: IProbMap
2125 """
2226 A function which takes the solution of `initializeprob` and returns
23- the parameter object of the original problem.
27+ the parameter object of the original problem. If absent (`nothing`),
28+ this will not be called and the parameters of the problem being
29+ initialized will be returned as-is.
2430 """
2531 initializeprobpmap:: IProbPmap
2632
@@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
3036 return new {I, J, K, L} (initprob, update_initprob!, initprobmap, initprobpmap)
3137 end
3238end
39+
40+ """
41+ get_initial_values(prob, valp, f, alg, isinplace; kwargs...)
42+
43+ Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
44+ and a boolean indicating whether the initialization process was successful. Keyword
45+ arguments to this function are dependent on the initialization algorithm. `prob` is only
46+ required for dispatching. `valp` refers the appropriate data structure from which the
47+ current state and parameter values should be obtained. `valp` is a non-timeseries value
48+ provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
49+ problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
50+ if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
51+ """
52+ function get_initial_values end
53+
54+ struct CheckInitFailureError <: Exception
55+ normresid:: Any
56+ abstol:: Any
57+ end
58+
59+ function Base. showerror (io:: IO , e:: CheckInitFailureError )
60+ print (io,
61+ " CheckInit specified but initialization not satisfied. normresid = $(e. normresid) > abstol = $(e. abstol) " )
62+ end
63+
64+ struct OverrideInitMissingAlgorithm <: Exception end
65+
66+ function Base. showerror (io:: IO , e:: OverrideInitMissingAlgorithm )
67+ print (io,
68+ " OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`." )
69+ end
70+
71+ """
72+ Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
73+ it is in-place or simply calling the function if not.
74+ """
75+ function _evaluate_f_ode (integrator, f, isinplace:: Val{true} , args... )
76+ tmp = first (get_tmp_cache (integrator))
77+ f (tmp, args... )
78+ return tmp
79+ end
80+
81+ function _evaluate_f_ode (integrator, f, isinplace:: Val{false} , args... )
82+ return f (args... )
83+ end
84+
85+ """
86+ $(TYPEDSIGNATURES)
87+
88+ A utility function equivalent to `Base.vec` but also handles `Number` and
89+ `AbstractSciMLScalarOperator`.
90+ """
91+ _vec (v) = vec (v)
92+ _vec (v:: Number ) = v
93+ _vec (v:: SciMLOperators.AbstractSciMLScalarOperator ) = v
94+ _vec (v:: AbstractVector ) = v
95+
96+ """
97+ $(TYPEDSIGNATURES)
98+
99+ Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100+ the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101+ `ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
102+ """
103+ function get_initial_values (prob:: ODEProblem , integrator, f, alg:: CheckInit ,
104+ isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
105+ u0 = state_values (integrator)
106+ p = parameter_values (integrator)
107+ t = current_time (integrator)
108+ M = f. mass_matrix
109+
110+ algebraic_vars = [all (iszero, x) for x in eachcol (M)]
111+ algebraic_eqs = [all (iszero, x) for x in eachrow (M)]
112+ (iszero (algebraic_vars) || iszero (algebraic_eqs)) && return
113+ update_coefficients! (M, u0, p, t)
114+ tmp = _evaluate_f_ode (integrator, f, isinplace, u0, p, t)
115+ tmp .= ArrayInterface. restructure (tmp, algebraic_eqs .* _vec (tmp))
116+
117+ normresid = integrator. opts. internalnorm (tmp, t)
118+ if normresid > integrator. opts. abstol
119+ throw (CheckInitFailureError (normresid, integrator. opts. abstol))
120+ end
121+ return u0, p, true
122+ end
123+
124+ """
125+ Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126+ it is in-place or simply calling the function if not.
127+ """
128+ function _evaluate_f_dae (integrator, f, isinplace:: Val{true} , args... )
129+ tmp = get_tmp_cache (integrator)[2 ]
130+ f (tmp, args... )
131+ return tmp
132+ end
133+
134+ function _evaluate_f_dae (integrator, f, isinplace:: Val{false} , args... )
135+ return f (args... )
136+ end
137+
138+ function get_initial_values (prob:: DAEProblem , integrator, f, alg:: CheckInit ,
139+ isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
140+ u0 = state_values (integrator)
141+ p = parameter_values (integrator)
142+ t = current_time (integrator)
143+
144+ resid = _evaluate_f_dae (integrator, f, isinplace, integrator. du, u0, p, t)
145+ normresid = integrator. opts. internalnorm (resid, t)
146+ if normresid > integrator. opts. abstol
147+ throw (CheckInitFailureError (normresid, integrator. opts. abstol))
148+ end
149+ return u0, p, true
150+ end
151+
152+ """
153+ $(TYPEDSIGNATURES)
154+
155+ Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156+ `p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157+ If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158+ The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159+ argument, failing which this function will throw an error. The success value returned
160+ depends on the success of the nonlinear solve.
161+ """
162+ function get_initial_values (prob, valp, f, alg:: OverrideInit ,
163+ isinplace:: Union{Val{true}, Val{false}} ; nlsolve_alg = nothing , kwargs... )
164+ u0 = state_values (valp)
165+ p = parameter_values (valp)
166+
167+ if ! has_initialization_data (f)
168+ return u0, p, true
169+ end
170+
171+ initdata:: OverrideInitData = f. initialization_data
172+ initprob = initdata. initializeprob
173+
174+ if nlsolve_alg === nothing
175+ throw (OverrideInitMissingAlgorithm ())
176+ end
177+
178+ if initdata. update_initializeprob! != = nothing
179+ initdata. update_initializeprob! (initprob, valp)
180+ end
181+
182+ nlsol = solve (initprob, nlsolve_alg)
183+
184+ u0 = initdata. initializeprobmap (nlsol)
185+ if initdata. initializeprobpmap != = nothing
186+ p = initdata. initializeprobpmap (nlsol)
187+ end
188+
189+ return u0, p, SciMLBase. successful_retcode (nlsol)
190+ end
0 commit comments