@@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
9
9
"""
10
10
initializeprob:: IProb
11
11
"""
12
- A function which takes `(initializeprob, prob )` and updates
12
+ A function which takes `(initializeprob, value_provider )` and updates
13
13
the parameters of the former with their values in the latter.
14
+ If absent (`nothing`) this will not be called, and the parameters
15
+ in `initializeprob` will be used without modification. `value_provider`
16
+ refers to a value provider as defined by SymbolicIndexingInterface.jl.
17
+ Usually this will refer to a problem or integrator.
14
18
"""
15
19
update_initializeprob!:: UIProb
16
20
"""
@@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
20
24
initializeprobmap:: IProbMap
21
25
"""
22
26
A function which takes the solution of `initializeprob` and returns
23
- the parameter object of the original problem.
27
+ the parameter object of the original problem. If absent (`nothing`),
28
+ this will not be called and the parameters of the problem being
29
+ initialized will be returned as-is.
24
30
"""
25
31
initializeprobpmap:: IProbPmap
26
32
@@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
30
36
return new {I, J, K, L} (initprob, update_initprob!, initprobmap, initprobpmap)
31
37
end
32
38
end
39
+
40
+ """
41
+ get_initial_values(prob, valp, f, alg, isinplace; kwargs...)
42
+
43
+ Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
44
+ and a boolean indicating whether the initialization process was successful. Keyword
45
+ arguments to this function are dependent on the initialization algorithm. `prob` is only
46
+ required for dispatching. `valp` refers the appropriate data structure from which the
47
+ current state and parameter values should be obtained. `valp` is a non-timeseries value
48
+ provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
49
+ problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
50
+ if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
51
+ """
52
+ function get_initial_values end
53
+
54
+ struct CheckInitFailureError <: Exception
55
+ normresid:: Any
56
+ abstol:: Any
57
+ end
58
+
59
+ function Base. showerror (io:: IO , e:: CheckInitFailureError )
60
+ print (io,
61
+ " CheckInit specified but initialization not satisfied. normresid = $(e. normresid) > abstol = $(e. abstol) " )
62
+ end
63
+
64
+ struct OverrideInitMissingAlgorithm <: Exception end
65
+
66
+ function Base. showerror (io:: IO , e:: OverrideInitMissingAlgorithm )
67
+ print (io,
68
+ " OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`." )
69
+ end
70
+
71
+ """
72
+ Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
73
+ it is in-place or simply calling the function if not.
74
+ """
75
+ function _evaluate_f_ode (integrator, f, isinplace:: Val{true} , args... )
76
+ tmp = first (get_tmp_cache (integrator))
77
+ f (tmp, args... )
78
+ return tmp
79
+ end
80
+
81
+ function _evaluate_f_ode (integrator, f, isinplace:: Val{false} , args... )
82
+ return f (args... )
83
+ end
84
+
85
+ """
86
+ $(TYPEDSIGNATURES)
87
+
88
+ A utility function equivalent to `Base.vec` but also handles `Number` and
89
+ `AbstractSciMLScalarOperator`.
90
+ """
91
+ _vec (v) = vec (v)
92
+ _vec (v:: Number ) = v
93
+ _vec (v:: SciMLOperators.AbstractSciMLScalarOperator ) = v
94
+ _vec (v:: AbstractVector ) = v
95
+
96
+ """
97
+ $(TYPEDSIGNATURES)
98
+
99
+ Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100
+ the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101
+ `ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
102
+ """
103
+ function get_initial_values (prob:: ODEProblem , integrator, f, alg:: CheckInit ,
104
+ isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
105
+ u0 = state_values (integrator)
106
+ p = parameter_values (integrator)
107
+ t = current_time (integrator)
108
+ M = f. mass_matrix
109
+
110
+ algebraic_vars = [all (iszero, x) for x in eachcol (M)]
111
+ algebraic_eqs = [all (iszero, x) for x in eachrow (M)]
112
+ (iszero (algebraic_vars) || iszero (algebraic_eqs)) && return
113
+ update_coefficients! (M, u0, p, t)
114
+ tmp = _evaluate_f_ode (integrator, f, isinplace, u0, p, t)
115
+ tmp .= ArrayInterface. restructure (tmp, algebraic_eqs .* _vec (tmp))
116
+
117
+ normresid = integrator. opts. internalnorm (tmp, t)
118
+ if normresid > integrator. opts. abstol
119
+ throw (CheckInitFailureError (normresid, integrator. opts. abstol))
120
+ end
121
+ return u0, p, true
122
+ end
123
+
124
+ """
125
+ Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126
+ it is in-place or simply calling the function if not.
127
+ """
128
+ function _evaluate_f_dae (integrator, f, isinplace:: Val{true} , args... )
129
+ tmp = get_tmp_cache (integrator)[2 ]
130
+ f (tmp, args... )
131
+ return tmp
132
+ end
133
+
134
+ function _evaluate_f_dae (integrator, f, isinplace:: Val{false} , args... )
135
+ return f (args... )
136
+ end
137
+
138
+ function get_initial_values (prob:: DAEProblem , integrator, f, alg:: CheckInit ,
139
+ isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
140
+ u0 = state_values (integrator)
141
+ p = parameter_values (integrator)
142
+ t = current_time (integrator)
143
+
144
+ resid = _evaluate_f_dae (integrator, f, isinplace, integrator. du, u0, p, t)
145
+ normresid = integrator. opts. internalnorm (resid, t)
146
+ if normresid > integrator. opts. abstol
147
+ throw (CheckInitFailureError (normresid, integrator. opts. abstol))
148
+ end
149
+ return u0, p, true
150
+ end
151
+
152
+ """
153
+ $(TYPEDSIGNATURES)
154
+
155
+ Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156
+ `p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157
+ If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158
+ The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159
+ argument, failing which this function will throw an error. The success value returned
160
+ depends on the success of the nonlinear solve.
161
+ """
162
+ function get_initial_values (prob, valp, f, alg:: OverrideInit ,
163
+ isinplace:: Union{Val{true}, Val{false}} ; nlsolve_alg = nothing , kwargs... )
164
+ u0 = state_values (valp)
165
+ p = parameter_values (valp)
166
+
167
+ if ! has_initialization_data (f)
168
+ return u0, p, true
169
+ end
170
+
171
+ initdata:: OverrideInitData = f. initialization_data
172
+ initprob = initdata. initializeprob
173
+
174
+ if nlsolve_alg === nothing
175
+ throw (OverrideInitMissingAlgorithm ())
176
+ end
177
+
178
+ if initdata. update_initializeprob! != = nothing
179
+ initdata. update_initializeprob! (initprob, valp)
180
+ end
181
+
182
+ nlsol = solve (initprob, nlsolve_alg)
183
+
184
+ u0 = initdata. initializeprobmap (nlsol)
185
+ if initdata. initializeprobpmap != = nothing
186
+ p = initdata. initializeprobpmap (nlsol)
187
+ end
188
+
189
+ return u0, p, SciMLBase. successful_retcode (nlsol)
190
+ end
0 commit comments