@@ -212,7 +212,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
212
212
end
213
213
214
214
function __mirk_loss! (resid, u, p, y, pt:: StandardBVProblem , bc!:: BC , residual, mesh,
215
- cache) where {BC <: Function }
215
+ cache) where {BC}
216
216
y_ = recursive_unflatten! (y, u)
217
217
resids = [get_tmp (r, u) for r in residual]
218
218
eval_bc_residual! (resids[1 ], pt, bc!, y_, p, mesh)
@@ -222,7 +222,7 @@ function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual,
222
222
end
223
223
224
224
function __mirk_loss! (resid, u, p, y, pt:: TwoPointBVProblem , bc!:: Tuple{BC1, BC2} , residual,
225
- mesh, cache) where {BC1 <: Function , BC2 <: Function }
225
+ mesh, cache) where {BC1, BC2}
226
226
y_ = recursive_unflatten! (y, u)
227
227
resids = [get_tmp (r, u) for r in residual]
228
228
resida = @view resids[1 ][1 : prod (cache. resid_size[1 ])]
@@ -233,29 +233,28 @@ function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2
233
233
return nothing
234
234
end
235
235
236
- function __mirk_loss (u, p, y, pt:: StandardBVProblem , bc:: BC , mesh,
237
- cache) where {BC <: Function }
236
+ function __mirk_loss (u, p, y, pt:: StandardBVProblem , bc:: BC , mesh, cache) where {BC}
238
237
y_ = recursive_unflatten! (y, u)
239
238
resid_bc = eval_bc_residual (pt, bc, y_, p, mesh)
240
239
resid_co = Φ (cache, y_, u, p)
241
240
return vcat (resid_bc, mapreduce (vec, vcat, resid_co))
242
241
end
243
242
244
243
function __mirk_loss (u, p, y, pt:: TwoPointBVProblem , bc:: Tuple{BC1, BC2} , mesh,
245
- cache) where {BC1 <: Function , BC2 <: Function }
244
+ cache) where {BC1, BC2}
246
245
y_ = recursive_unflatten! (y, u)
247
246
resid_bca, resid_bcb = eval_bc_residual (pt, bc, y_, p, mesh)
248
247
resid_co = Φ (cache, y_, u, p)
249
248
return vcat (resid_bca, mapreduce (vec, vcat, resid_co), resid_bcb)
250
249
end
251
250
252
- function __mirk_loss_bc! (resid, u, p, pt, bc!:: BC , y, mesh) where {BC <: Function }
251
+ function __mirk_loss_bc! (resid, u, p, pt, bc!:: BC , y, mesh) where {BC}
253
252
y_ = recursive_unflatten! (y, u)
254
253
eval_bc_residual! (resid, pt, bc!, y_, p, mesh)
255
254
return nothing
256
255
end
257
256
258
- function __mirk_loss_bc (u, p, pt, bc!:: BC , y, mesh) where {BC <: Function }
257
+ function __mirk_loss_bc (u, p, pt, bc!:: BC , y, mesh) where {BC}
259
258
y_ = recursive_unflatten! (y, u)
260
259
return eval_bc_residual (pt, bc!, y_, p, mesh)
261
260
end
@@ -275,7 +274,7 @@ function __mirk_loss_collocation(u, p, y, mesh, residual, cache)
275
274
end
276
275
277
276
function __construct_nlproblem (cache:: MIRKCache{iip} , y, loss_bc:: BC , loss_collocation:: C ,
278
- loss:: L , :: StandardBVProblem ) where {iip, BC <: Function , C <: Function , L <: Function }
277
+ loss:: L , :: StandardBVProblem ) where {iip, BC, C, L}
279
278
@unpack nlsolve, jac_alg = cache. alg
280
279
N = length (cache. mesh)
281
280
@@ -317,24 +316,23 @@ end
317
316
318
317
function __mirk_mpoint_jacobian! (J, x, p, bc_diffmode, nonbc_diffmode, bc_diffcache,
319
318
nonbc_diffcache, loss_bc:: BC , loss_collocation:: C , resid_bc, resid_collocation,
320
- M:: Int ) where {BC <: Function , C <: Function }
319
+ M:: Int ) where {BC, C}
321
320
sparse_jacobian! (@view (J[1 : M, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x)
322
321
sparse_jacobian! (@view (J[(M + 1 ): end , :]), nonbc_diffmode, nonbc_diffcache,
323
322
loss_collocation, resid_collocation, x)
324
323
return nothing
325
324
end
326
325
327
326
function __mirk_mpoint_jacobian (x, p, J, bc_diffmode, nonbc_diffmode, bc_diffcache,
328
- nonbc_diffcache, loss_bc:: BC , loss_collocation:: C ,
329
- M:: Int ) where {BC <: Function , C <: Function }
327
+ nonbc_diffcache, loss_bc:: BC , loss_collocation:: C , M:: Int ) where {BC, C}
330
328
sparse_jacobian! (@view (J[1 : M, :]), bc_diffmode, bc_diffcache, loss_bc, x)
331
329
sparse_jacobian! (@view (J[(M + 1 ): end , :]), nonbc_diffmode, nonbc_diffcache,
332
330
loss_collocation, x)
333
331
return J
334
332
end
335
333
336
334
function __construct_nlproblem (cache:: MIRKCache{iip} , y, loss_bc:: BC , loss_collocation:: C ,
337
- loss:: L , :: TwoPointBVProblem ) where {iip, BC <: Function , C <: Function , L <: Function }
335
+ loss:: L , :: TwoPointBVProblem ) where {iip, BC, C, L}
338
336
@unpack nlsolve, jac_alg = cache. alg
339
337
N = length (cache. mesh)
340
338
@@ -366,14 +364,12 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
366
364
return NonlinearProblem (NonlinearFunction {iip} (loss; jac, jac_prototype), y, cache. p)
367
365
end
368
366
369
- function __mirk_2point_jacobian! (J, x, p, diffmode, diffcache, loss_fn:: L ,
370
- resid) where {L <: Function }
367
+ function __mirk_2point_jacobian! (J, x, p, diffmode, diffcache, loss_fn:: L , resid) where {L}
371
368
sparse_jacobian! (J, diffmode, diffcache, loss_fn, resid, x)
372
369
return J
373
370
end
374
371
375
- function __mirk_2point_jacobian (x, p, J, diffmode, diffcache,
376
- loss_fn:: L ) where {L <: Function }
372
+ function __mirk_2point_jacobian (x, p, J, diffmode, diffcache, loss_fn:: L ) where {L}
377
373
sparse_jacobian! (J, diffmode, diffcache, loss_fn, x)
378
374
return J
379
375
end
0 commit comments