@@ -107,53 +107,163 @@ function SciMLBase.solve!(cache::MIRKNCache{iip, T}) where {iip, T}
107
107
end
108
108
109
109
function __perform_mirkn_iteration (cache:: MIRKNCache ; nlsolve_kwargs = (;), kwargs... )
110
- nlprob:: NonlinearProblem = __construct_nlproblem (cache, vec (cache. y₀))
110
+ nlprob:: NonlinearProblem = __construct_nlproblem (cache, vec (cache. y₀), copy (cache . y₀) )
111
111
nlsolve_alg = __concrete_nonlinearsolve_algorithm (nlprob, cache. alg. nlsolve)
112
112
sol_nlprob = __solve (nlprob, nlsolve_alg; kwargs... , nlsolve_kwargs... , alias_u0 = true )
113
113
recursive_unflatten! (cache. y₀, sol_nlprob. u)
114
114
115
115
return sol_nlprob, sol_nlprob. retcode
116
116
end
117
117
118
- function __construct_nlproblem (cache :: MIRKNCache{iip} , y) where {iip}
119
- (; jac_alg) = cache . alg
120
- (; diffmode) = jac_alg
118
+ # Constructing the Nonlinear Problem
119
+ function __construct_nlproblem (
120
+ cache :: MIRKNCache{iip} , y :: AbstractVector , y₀ :: AbstractVectorOfArray ) where {iip}
121
121
pt = cache. problem_type
122
122
123
+ eval_sol = EvalSol (
124
+ __restructure_sol (y₀. u[1 : length (cache. mesh)], cache. in_size), cache. mesh, cache)
125
+ eval_dsol = EvalSol (
126
+ __restructure_sol (y₀. u[(length (cache. mesh) + 1 ): end ], cache. in_size),
127
+ cache. mesh, cache)
128
+
129
+ loss_bc = if iip
130
+ @closure (du, u, p) -> __mirkn_loss_bc! (
131
+ du, u, p, pt, cache. bc, cache. y, cache. mesh, cache)
132
+ else
133
+ @closure (u, p) -> __mirkn_loss_bc (u, p, pt, cache. bc, cache. y, cache. mesh, cache)
134
+ end
135
+
136
+ loss_collocation = if iip
137
+ @closure (du, u, p) -> __mirkn_loss_collocation! (
138
+ du, u, p, cache. y, cache. mesh, cache. residual, cache)
139
+ else
140
+ @closure (u, p) -> __mirkn_loss_collocation (
141
+ u, p, cache. y, cache. mesh, cache. residual, cache)
142
+ end
143
+
123
144
loss = if iip
124
145
@closure (du, u, p) -> __mirkn_loss! (
125
- du, u, p, cache. y, pt, cache. bc, cache. residual, cache. mesh, cache)
146
+ du, u, p, cache. y, pt, cache. bc, cache. residual,
147
+ cache. mesh, cache, eval_sol, eval_dsol)
126
148
else
127
- @closure (u, p) -> __mirkn_loss (u, p, cache. y, pt, cache. bc, cache. mesh, cache)
149
+ @closure (u, p) -> __mirkn_loss (
150
+ u, p, cache. y, pt, cache. bc, cache. mesh, cache, eval_sol, eval_dsol)
128
151
end
129
152
130
- resid_prototype = __similar (y)
153
+ return __construct_nlproblem (cache, y, loss_bc, loss_collocation, loss, pt)
154
+ end
155
+
156
+ function __construct_nlproblem (cache:: MIRKNCache{iip} , y, loss_bc:: BC , loss_collocation:: C ,
157
+ loss:: LF , :: StandardSecondOrderBVProblem ) where {iip, BC, C, LF}
158
+ (; jac_alg) = cache. alg
159
+ N = length (cache. mesh)
131
160
132
- jac_cache = if iip
133
- DI. prepare_jacobian (loss, resid_prototype, diffmode, y, Constant (cache. p))
161
+ resid_bc = cache. bcresid_prototype
162
+ L = length (resid_bc)
163
+ resid_collocation = __similar (y, cache. M * (2 * N - 2 ))
164
+
165
+ bc_diffmode = if jac_alg. bc_diffmode isa AutoSparse
166
+ AutoSparse (get_dense_ad (jac_alg. bc_diffmode);
167
+ sparsity_detector = __default_sparsity_detector (jac_alg. bc_diffmode),
168
+ coloring_algorithm = __default_coloring_algorithm (jac_alg. bc_diffmode))
134
169
else
135
- DI . prepare_jacobian (loss, diffmode, y, Constant (cache . p))
170
+ jac_alg . bc_diffmode
136
171
end
137
172
138
- jac_prototype = if iip
139
- DI. jacobian (loss, resid_prototype, jac_cache, diffmode, y, Constant (cache. p))
173
+ cache_bc = if iip
174
+ DI. prepare_jacobian (loss_bc, resid_bc, bc_diffmode, y, Constant (cache. p))
175
+ else
176
+ DI. prepare_jacobian (loss_bc, bc_diffmode, y, Constant (cache. p))
177
+ end
178
+
179
+ nonbc_diffmode = if jac_alg. nonbc_diffmode isa AutoSparse
180
+ AutoSparse (get_dense_ad (jac_alg. nonbc_diffmode);
181
+ sparsity_detector = __default_sparsity_detector (jac_alg. nonbc_diffmode),
182
+ coloring_algorithm = __default_coloring_algorithm (jac_alg. nonbc_diffmode))
140
183
else
141
- DI . jacobian (loss, jac_cache, diffmode, y, Constant (cache . p))
184
+ jac_alg . nonbc_diffmode
142
185
end
143
186
187
+ cache_collocation = if iip
188
+ DI. prepare_jacobian (
189
+ loss_collocation, resid_collocation, nonbc_diffmode, y, Constant (cache. p))
190
+ else
191
+ DI. prepare_jacobian (loss_collocation, nonbc_diffmode, y, Constant (cache. p))
192
+ end
193
+
194
+ J_bc = if iip
195
+ DI. jacobian (loss_bc, resid_bc, cache_bc, bc_diffmode, y, Constant (cache. p))
196
+ else
197
+ DI. jacobian (loss_bc, cache_bc, bc_diffmode, y, Constant (cache. p))
198
+ end
199
+ J_c = if iip
200
+ DI. jacobian (loss_collocation, resid_collocation, cache_collocation,
201
+ nonbc_diffmode, y, Constant (cache. p))
202
+ else
203
+ DI. jacobian (
204
+ loss_collocation, cache_collocation, nonbc_diffmode, y, Constant (cache. p))
205
+ end
206
+
207
+ jac_prototype = vcat (J_bc, J_c)
208
+
144
209
jac = if iip
145
210
@closure (J, u, p) -> __mirkn_mpoint_jacobian! (
146
- J, u, diffmode, jac_cache, loss, resid_prototype, cache. p)
211
+ J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation,
212
+ loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache. p)
147
213
else
148
214
@closure (u, p) -> __mirkn_mpoint_jacobian (
149
- jac_prototype, u, diffmode, jac_cache, loss, cache. p)
215
+ jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc,
216
+ cache_collocation, loss_bc, loss_collocation, L, cache. p)
150
217
end
151
-
218
+ resid_prototype = vcat (resid_bc, resid_collocation)
152
219
nlf = NonlinearFunction {iip} (
153
220
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
154
221
__internal_nlsolve_problem (cache. prob, resid_prototype, y, nlf, y, cache. p)
155
222
end
156
223
224
+ function __construct_nlproblem (cache:: MIRKNCache{iip} , y, loss_bc:: BC , loss_collocation:: C ,
225
+ loss:: LF , :: TwoPointSecondOrderBVProblem ) where {iip, BC, C, LF}
226
+ (; nlsolve, jac_alg) = cache. alg
227
+ N = length (cache. mesh)
228
+
229
+ resid = vcat (@view (cache. bcresid_prototype[1 : prod (cache. resid_size[1 ])]),
230
+ __similar (y, cache. M * 2 * (N - 1 )),
231
+ @view (cache. bcresid_prototype[(prod (cache. resid_size[1 ]) + 1 ): end ]))
232
+
233
+ diffmode = if jac_alg. diffmode isa AutoSparse
234
+ AutoSparse (get_dense_ad (jac_alg. diffmode);
235
+ sparsity_detector = __default_sparsity_detector (jac_alg. diffmode),
236
+ coloring_algorithm = __default_coloring_algorithm (jac_alg. diffmode))
237
+ else
238
+ jac_alg. diffmode
239
+ end
240
+
241
+ diffcache = if iip
242
+ DI. prepare_jacobian (loss, resid, diffmode, y, Constant (cache. p))
243
+ else
244
+ DI. prepare_jacobian (loss, diffmode, y, Constant (cache. p))
245
+ end
246
+
247
+ jac_prototype = if iip
248
+ DI. jacobian (loss, resid, diffcache, diffmode, y, Constant (cache. p))
249
+ else
250
+ DI. jacobian (loss, diffcache, diffmode, y, Constant (cache. p))
251
+ end
252
+
253
+ jac = if iip
254
+ @closure (J, u, p) -> __mirkn_2point_jacobian! (
255
+ J, u, jac_alg. diffmode, diffcache, loss, resid, p)
256
+ else
257
+ @closure (u, p) -> __mirkn_2point_jacobian (
258
+ u, jac_prototype, jac_alg. diffmode, diffcache, loss, p)
259
+ end
260
+
261
+ resid_prototype = copy (resid)
262
+ nlf = NonlinearFunction {iip} (
263
+ loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
264
+ return __internal_nlsolve_problem (cache. prob, resid_prototype, y, nlf, y, cache. p)
265
+ end
266
+
157
267
function __mirkn_2point_jacobian! (J, x, diffmode, diffcache, loss_fn:: L , resid, p) where {L}
158
268
DI. jacobian! (loss_fn, resid, J, diffcache, diffmode, x, Constant (p))
159
269
return J
@@ -164,44 +274,85 @@ function __mirkn_2point_jacobian(x, J, diffmode, diffcache, loss_fn::L, p) where
164
274
return J
165
275
end
166
276
167
- function __mirkn_mpoint_jacobian! (J, x, diffmode, diffcache, loss, resid, p)
168
- DI. jacobian! (loss, resid, J, diffcache, diffmode, x, Constant (p))
277
+ function __mirkn_mpoint_jacobian! (
278
+ J, _, x, bc_diffmode, nonbc_diffmode, bc_diffcache, nonbc_diffcache, loss_bc:: BC ,
279
+ loss_collocation:: C , resid_bc, resid_collocation, L:: Int , p) where {BC, C}
280
+ DI. jacobian! (
281
+ loss_bc, resid_bc, @view (J[1 : L, :]), bc_diffcache, bc_diffmode, x, Constant (p))
282
+ DI. jacobian! (loss_collocation, resid_collocation, @view (J[(L + 1 ): end , :]),
283
+ nonbc_diffcache, nonbc_diffmode, x, Constant (p))
169
284
return nothing
170
285
end
171
286
172
- function __mirkn_mpoint_jacobian (J, x, diffmode, diffcache, loss, p)
173
- DI. jacobian! (loss, J, diffcache, diffmode, x, Constant (p))
287
+ function __mirkn_mpoint_jacobian (
288
+ J, _, x, bc_diffmode, nonbc_diffmode, bc_diffcache, nonbc_diffcache,
289
+ loss_bc:: BC , loss_collocation:: C , L:: Int , p) where {BC, C}
290
+ DI. jacobian! (loss_bc, @view (J[1 : L, :]), bc_diffcache, bc_diffmode, x, Constant (p))
291
+ DI. jacobian! (loss_collocation, @view (J[(L + 1 ): end , :]),
292
+ nonbc_diffcache, nonbc_diffmode, x, Constant (p))
174
293
return J
175
294
end
176
295
177
- @views function __mirkn_loss! (resid, u, p, y, pt:: StandardSecondOrderBVProblem ,
178
- bc :: BC , residual, mesh, cache:: MIRKNCache ) where {BC}
296
+ @views function __mirkn_loss! (resid, u, p, y, pt:: StandardSecondOrderBVProblem , bc :: BC ,
297
+ residual, mesh, cache:: MIRKNCache , EvalSol, EvalDSol ) where {BC}
179
298
y_ = recursive_unflatten! (y, u)
180
299
resids = [get_tmp (r, u) for r in residual]
181
300
Φ! (resids[3 : end ], cache, y_, u, p)
182
- soly_ = EvalSol (
183
- __restructure_sol (y_ [1 : length (cache . mesh)], cache . in_size), cache. mesh, cache)
184
- dsoly_ = EvalSol ( __restructure_sol (y_[(length (cache. mesh) + 1 ): end ], cache. in_size),
185
- cache. mesh, cache)
186
- eval_bc_residual! (resids[1 : 2 ], pt, bc, soly_, dsoly_ , p, mesh)
301
+ EvalSol . u[ 1 : end ] . = __restructure_sol (y_[ 1 : length (cache . mesh)], cache . in_size)
302
+ EvalSol . cache . k_discrete [1 : end ] . = cache. k_discrete
303
+ EvalDSol . u[ 1 : end ] . = __restructure_sol (y_[(length (cache. mesh) + 1 ): end ], cache. in_size)
304
+ EvalDSol . cache. k_discrete[ 1 : end ] . = cache. k_discrete
305
+ eval_bc_residual! (resids[1 : 2 ], pt, bc, EvalSol, EvalDSol , p, mesh)
187
306
recursive_flatten! (resid, resids)
188
307
return nothing
189
308
end
190
309
191
- @views function __mirkn_loss (u, p, y, pt:: StandardSecondOrderBVProblem ,
192
- bc :: BC , mesh, cache:: MIRKNCache ) where {BC}
310
+ @views function __mirkn_loss (u, p, y, pt:: StandardSecondOrderBVProblem , bc :: BC ,
311
+ mesh, cache:: MIRKNCache , EvalSol, EvalDSol ) where {BC}
193
312
y_ = recursive_unflatten! (y, u)
194
313
resid_co = Φ (cache, y_, u, p)
195
- soly_ = EvalSol (
196
- __restructure_sol (y_[1 : length (cache. mesh)], cache. in_size), cache. mesh, cache)
314
+ EvalSol. u[1 : end ] .= __restructure_sol (y_[1 : length (cache. mesh)], cache. in_size)
315
+ EvalSol. cache. k_discrete[1 : end ] .= cache. k_discrete
316
+ EvalDSol. u[1 : end ] .= __restructure_sol (y_[(length (cache. mesh) + 1 ): end ], cache. in_size)
317
+ EvalDSol. cache. k_discrete[1 : end ] .= cache. k_discrete
318
+ resid_bc = eval_bc_residual (pt, bc, EvalSol, EvalDSol, p, mesh)
319
+ return vcat (resid_bc, mapreduce (vec, vcat, resid_co))
320
+ end
321
+
322
+ @views function __mirkn_loss_bc! (
323
+ resid, u, p, pt, bc!:: BC , y, mesh, cache:: MIRKNCache ) where {BC}
324
+ y_ = recursive_unflatten! (y, u)
325
+ soly_ = EvalSol (__restructure_sol (y_[1 : length (cache. mesh)], cache. in_size), mesh, cache)
197
326
dsoly_ = EvalSol (__restructure_sol (y_[(length (cache. mesh) + 1 ): end ], cache. in_size),
198
327
cache. mesh, cache)
199
- resid_bc = eval_bc_residual (pt, bc, soly_, dsoly_, p, mesh)
200
- return vcat (resid_bc, mapreduce (vec, vcat, resid_co))
328
+ eval_bc_residual! (resid, pt, bc!, soly_, dsoly_, p, mesh)
329
+ return nothing
330
+ end
331
+
332
+ @views function __mirkn_loss_bc (u, p, pt, bc!:: BC , y, mesh, cache:: MIRKNCache ) where {BC}
333
+ y_ = recursive_unflatten! (y, u)
334
+ soly_ = EvalSol (__restructure_sol (y_[1 : length (cache. mesh)], cache. in_size), mesh, cache)
335
+ dsoly_ = EvalSol (__restructure_sol (y_[(length (cache. mesh) + 1 ): end ], cache. in_size),
336
+ cache. mesh, cache)
337
+ return eval_bc_residual (pt, bc!, soly_, dsoly_, p, mesh)
338
+ end
339
+
340
+ @views function __mirkn_loss_collocation! (resid, u, p, y, mesh, residual, cache)
341
+ y_ = recursive_unflatten! (y, u)
342
+ resids = [get_tmp (r, u) for r in residual[3 : end ]]
343
+ Φ! (resids, cache, y_, u, p)
344
+ recursive_flatten! (resid, resids)
345
+ return nothing
346
+ end
347
+
348
+ @views function __mirkn_loss_collocation (u, p, y, mesh, residual, cache)
349
+ y_ = recursive_unflatten! (y, u)
350
+ resids = Φ (cache, y_, u, p)
351
+ return mapreduce (vec, vcat, resids)
201
352
end
202
353
203
- @views function __mirkn_loss! (resid, u, p, y, pt:: TwoPointSecondOrderBVProblem ,
204
- bc! :: BC , residual, mesh, cache:: MIRKNCache ) where {BC}
354
+ @views function __mirkn_loss! (resid, u, p, y, pt:: TwoPointSecondOrderBVProblem , bc! :: BC ,
355
+ residual, mesh, cache:: MIRKNCache , _, _ ) where {BC}
205
356
y_ = recursive_unflatten! (y, u)
206
357
resids = [get_tmp (r, u) for r in residual]
207
358
soly_ = VectorOfArray (y_)
212
363
end
213
364
214
365
@views function __mirkn_loss (u, p, y, pt:: TwoPointSecondOrderBVProblem ,
215
- bc!:: BC , mesh, cache:: MIRKNCache ) where {BC}
366
+ bc!:: BC , mesh, cache:: MIRKNCache , _, _ ) where {BC}
216
367
y_ = recursive_unflatten! (y, u)
217
368
soly_ = VectorOfArray (y_)
218
369
resid_co = Φ (cache, y_, u, p)
0 commit comments