25
25
fᵢ₂_cache
26
26
defect
27
27
new_stages
28
+ resid_size
28
29
kwargs
29
30
end
30
31
@@ -64,8 +65,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
64
65
bcresid_prototype, resid₁_size = __get_bcresid_prototype (prob. problem_type, prob, X)
65
66
66
67
residual = if iip
67
- vcat ([__alloc_diffcache (bcresid_prototype)],
68
- __alloc_diffcache .(copy .(@view (y₀[2 : end ]))))
68
+ if prob. problem_type isa TwoPointBVProblem
69
+ vcat ([__alloc_diffcache (__vec (bcresid_prototype))],
70
+ __alloc_diffcache .(copy .(@view (y₀[2 : end ]))))
71
+ else
72
+ vcat ([__alloc_diffcache (bcresid_prototype)],
73
+ __alloc_diffcache .(copy .(@view (y₀[2 : end ]))))
74
+ end
69
75
else
70
76
nothing
71
77
end
@@ -74,6 +80,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
74
80
new_stages = [similar (X, ifelse (adaptive, M, 0 )) for _ in 1 : n]
75
81
76
82
# Transform the functions to handle non-vector inputs
83
+ bcresid_prototype = __vec (bcresid_prototype)
77
84
f, bc = if X isa AbstractVector
78
85
prob. f, prob. f. bc
79
86
elseif iip
@@ -92,7 +99,6 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
92
99
end
93
100
(__vecbc_a!, __vecbc_b!)
94
101
end
95
- bcresid_prototype = vec (bcresid_prototype)
96
102
vecf!, vecbc!
97
103
else
98
104
vecf (u, p, t) = vec (prob. f (reshape (u, size (X)), p, t))
@@ -103,14 +109,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
103
109
__vecbc_b (ub, p) = vec (prob. f. bc[2 ](reshape (ub, size (X)), p))
104
110
(__vecbc_a, __vecbc_b)
105
111
end
106
- bcresid_prototype = vec (bcresid_prototype)
107
112
vecf, vecbc
108
113
end
109
114
110
115
return MIRKCache {iip, T} (alg_order (alg), stage, M, size (X), f, bc, prob,
111
116
prob. problem_type, prob. p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
112
117
k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages,
113
- (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs... ))
118
+ resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs... ))
114
119
end
115
120
116
121
"""
@@ -224,13 +229,21 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
224
229
end
225
230
226
231
loss = if iip
227
- function loss_internal! (resid:: AbstractVector , u:: AbstractVector , p = cache. p)
232
+ @views function loss_internal! (resid:: AbstractVector ,
233
+ u:: AbstractVector ,
234
+ p = cache. p)
228
235
y_ = recursive_unflatten! (cache. y, u)
229
236
resids = [get_tmp (r, u) for r in cache. residual]
230
- eval_bc_residual! (resids[1 ], cache. problem_type, cache. bc, y_, p, cache. mesh)
237
+ resid_bc = if cache. problem_type isa TwoPointBVProblem
238
+ (resids[1 ][1 : prod (cache. resid_size[1 ])],
239
+ resids[1 ][(prod (cache. resid_size[1 ]) + 1 ): end ])
240
+ else
241
+ resids[1 ]
242
+ end
243
+ eval_bc_residual! (resid_bc, cache. problem_type, cache. bc, y_, p, cache. mesh)
231
244
Φ! (resids[2 : end ], cache, y_, u, p)
232
245
if cache. problem_type isa TwoPointBVProblem
233
- recursive_flatten_twopoint! (resid, resids)
246
+ recursive_flatten_twopoint! (resid, resids, cache . resid_size )
234
247
else
235
248
recursive_flatten! (resid, resids)
236
249
end
@@ -242,7 +255,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
242
255
resid_bc = eval_bc_residual (cache. problem_type, cache. bc, y_, p, cache. mesh)
243
256
resid_co = Φ (cache, y_, u, p)
244
257
if cache. problem_type isa TwoPointBVProblem
245
- return vcat (resid_bc. x [1 ], mapreduce (vec, vcat, resid_co), resid_bc. x [2 ])
258
+ return vcat (resid_bc[1 ], mapreduce (vec, vcat, resid_co), resid_bc[2 ])
246
259
else
247
260
return vcat (resid_bc, mapreduce (vec, vcat, resid_co))
248
261
end
@@ -268,7 +281,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
268
281
269
282
sd_collocation = if jac_alg. nonbc_diffmode isa AbstractSparseADType
270
283
PrecomputedJacobianColorvec (__generate_sparse_jacobian_prototype (cache,
271
- cache. problem_type, y, cache. M, N))
284
+ cache. problem_type, y, y, cache. M, N))
272
285
else
273
286
NoSparsityDetection ()
274
287
end
@@ -299,19 +312,20 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
299
312
return NonlinearProblem (NonlinearFunction {iip} (loss; jac, jac_prototype), y, cache. p)
300
313
end
301
314
302
- function __construct_nlproblem (cache:: MIRKCache{iip} , y, loss_bc, loss_collocation, loss,
303
- :: TwoPointBVProblem ) where {iip}
315
+ function __construct_nlproblem (cache:: MIRKCache{iip} , y, loss_bc, loss_collocation,
316
+ loss, :: TwoPointBVProblem ) where {iip}
304
317
@unpack nlsolve, jac_alg = cache. alg
305
318
N = length (cache. mesh)
306
319
307
- resid = ArrayPartition (cache. bcresid_prototype, similar (y, cache. M * (N - 1 )))
320
+ resid = vcat (cache. bcresid_prototype[1 : prod (cache. resid_size[1 ])],
321
+ similar (y, cache. M * (N - 1 )),
322
+ cache. bcresid_prototype[(prod (cache. resid_size[1 ]) + 1 ): end ])
308
323
309
- # TODO : We can splitup the computation here as well similar to the Multiple Shooting
310
- # TODO : code. That way for the BC part the actual jacobian computation is even cheaper
311
- # TODO : Remember to not reorder if we end up using that implementation
312
324
sd = if jac_alg. diffmode isa AbstractSparseADType
313
325
PrecomputedJacobianColorvec (__generate_sparse_jacobian_prototype (cache,
314
- cache. problem_type, resid. x[1 ], cache. M, N))
326
+ cache. problem_type, @view (cache. bcresid_prototype[1 : prod (cache. resid_size[1 ])]),
327
+ @view (cache. bcresid_prototype[(prod (cache. resid_size[1 ]) + 1 ): end ]), cache. M,
328
+ N))
315
329
else
316
330
NoSparsityDetection ()
317
331
end
0 commit comments