@@ -107,9 +107,163 @@ end
107
107
maybe_codegen_scimlproblem (expression, SteadyStateProblem{iip}, args; kwargs... )
108
108
end
109
109
110
+ @fallback_iip_specialize function SemilinearODEFunction {iip, specialize} (
111
+ sys:: System ; u0 = nothing , p = nothing , t = nothing ,
112
+ semiquadratic_form = nothing ,
113
+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
114
+ eval_expression = false , eval_module = @__MODULE__ ,
115
+ expression = Val{false }, sparse = false , check_compatibility = true ,
116
+ jac = false , checkbounds = false , cse = true , initialization_data = nothing ,
117
+ analytic = nothing , kwargs... ) where {iip, specialize}
118
+ check_complete (sys, SemilinearODEFunction)
119
+ check_compatibility && check_compatible_system (SemilinearODEFunction, sys)
120
+
121
+ if semiquadratic_form === nothing
122
+ semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
123
+ sys = add_semiquadratic_parameters (sys, semiquadratic_form... )
124
+ end
125
+
126
+ A, B, C = semiquadratic_form
127
+ M = calculate_massmatrix (sys)
128
+ _M = concrete_massmatrix (M; sparse, u0)
129
+ dvs = unknowns (sys)
130
+
131
+ f1,
132
+ f2 = generate_semiquadratic_functions (
133
+ sys, A, B, C; stiff_linear, stiff_quadratic,
134
+ stiff_nonlinear, expression, wrap_gfw = Val{true },
135
+ eval_expression, eval_module, kwargs... )
136
+
137
+ if jac
138
+ Cjac = (C === nothing || ! stiff_nonlinear) ? nothing : Symbolics. jacobian (C, dvs)
139
+ _jac = generate_semiquadratic_jacobian (
140
+ sys, A, B, C, Cjac; sparse, expression,
141
+ wrap_gfw = Val{true }, eval_expression, eval_module, kwargs... )
142
+ _W_sparsity = get_semiquadratic_W_sparsity (
143
+ sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
144
+ W_prototype = calculate_W_prototype (_W_sparsity; u0, sparse)
145
+ else
146
+ _jac = nothing
147
+ W_prototype = nothing
148
+ end
149
+
150
+ observedfun = ObservedFunctionCache (
151
+ sys; expression, steady_state = false , eval_expression, eval_module, checkbounds, cse)
152
+
153
+ args = (; f1)
154
+ kwargs = (; jac = _jac, jac_prototype = W_prototype)
155
+ f1 = maybe_codegen_scimlfn (expression, ODEFunction{iip, specialize}, args; kwargs... )
156
+
157
+ args = (; f1, f2)
158
+ kwargs = (;
159
+ sys = sys,
160
+ jac = _jac,
161
+ mass_matrix = _M,
162
+ jac_prototype = W_prototype,
163
+ observed = observedfun,
164
+ analytic,
165
+ initialization_data)
166
+
167
+ return maybe_codegen_scimlfn (
168
+ expression, SplitFunction{iip, specialize}, args; kwargs... )
169
+ end
170
+
171
+ @fallback_iip_specialize function SemilinearODEProblem {iip, spec} (
172
+ sys:: System , op, tspan; check_compatibility = true , u0_eltype = nothing ,
173
+ expression = Val{false }, callback = nothing , sparse = false ,
174
+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
175
+ jac = false , kwargs... ) where {
176
+ iip, spec}
177
+ check_complete (sys, SemilinearODEProblem)
178
+ check_compatibility && check_compatible_system (SemilinearODEProblem, sys)
179
+
180
+ A, B, C = semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
181
+ eqs = equations (sys)
182
+ dvs = unknowns (sys)
183
+
184
+ sys = add_semiquadratic_parameters (sys, A, B, C)
185
+ if A != = nothing
186
+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
187
+ else
188
+ linear_matrix_param = nothing
189
+ end
190
+ if B != = nothing
191
+ quadratic_forms = [unwrap (getproperty (sys, get_quadratic_form_name (i)))
192
+ for i in 1 : length (eqs)]
193
+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
194
+ else
195
+ quadratic_forms = diffcache_par = nothing
196
+ end
197
+
198
+ op = to_varmap (op, dvs)
199
+ floatT = calculate_float_type (op, typeof (op))
200
+ _u0_eltype = something (u0_eltype, floatT)
201
+
202
+ guess = copy (guesses (sys))
203
+ defs = copy (defaults (sys))
204
+ if A != = nothing
205
+ guess[linear_matrix_param] = fill (NaN , size (A))
206
+ defs[linear_matrix_param] = A
207
+ end
208
+ if B != = nothing
209
+ for (par, mat) in zip (quadratic_forms, B)
210
+ guess[par] = fill (NaN , size (mat))
211
+ defs[par] = mat
212
+ end
213
+ cachelen = jac ? length (dvs) * length (eqs) : length (dvs)
214
+ defs[diffcache_par] = DiffCache (zeros (DiffEqBase. value (_u0_eltype), cachelen))
215
+ end
216
+ @set! sys. guesses = guess
217
+ @set! sys. defaults = defs
218
+
219
+ f, u0,
220
+ p = process_SciMLProblem (SemilinearODEFunction{iip, spec}, sys, op;
221
+ t = tspan != = nothing ? tspan[1 ] : tspan, expression, check_compatibility,
222
+ semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs... )
223
+
224
+ kwargs = process_kwargs (sys; expression, callback, kwargs... )
225
+
226
+ args = (; f, u0, tspan, p)
227
+ maybe_codegen_scimlproblem (expression, SplitODEProblem{iip}, args; kwargs... )
228
+ end
229
+
230
+ """
231
+ $(TYPEDSIGNATURES)
232
+
233
+ Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
234
+ `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
235
+ """
236
+ function add_semiquadratic_parameters (sys:: System , A, B, C)
237
+ eqs = equations (sys)
238
+ n = length (eqs)
239
+ var_to_name = copy (get_var_to_name (sys))
240
+ if B != = nothing
241
+ for i in eachindex (B)
242
+ B[i] === nothing && continue
243
+ par = get_quadratic_form_param ((n, n), i)
244
+ var_to_name[get_quadratic_form_name (i)] = par
245
+ sys = with_additional_constant_parameter (sys, par)
246
+ end
247
+ par = get_diffcache_param (Float64)
248
+ var_to_name[DIFFCACHE_PARAM_NAME] = par
249
+ sys = with_additional_nonnumeric_parameter (sys, par)
250
+ end
251
+ if A != = nothing
252
+ par = get_linear_matrix_param ((n, n))
253
+ var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
254
+ sys = with_additional_constant_parameter (sys, par)
255
+ end
256
+ @set! sys. var_to_name = var_to_name
257
+ if get_parent (sys) != = nothing
258
+ @set! sys. parent = add_semiquadratic_parameters (get_parent (sys), A, B, C)
259
+ end
260
+ return sys
261
+ end
262
+
110
263
function check_compatible_system (
111
264
T:: Union {Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
112
- Type{DAEProblem}, Type{SteadyStateProblem}},
265
+ Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
266
+ Type{SemilinearODEProblem}},
113
267
sys:: System )
114
268
check_time_dependent (sys, T)
115
269
check_not_dde (sys)
0 commit comments