@@ -24,9 +24,11 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
24
24
end
25
25
26
26
internal_ode_kwargs = (; verbose, kwargs... , odesolve_kwargs... , save_end = true )
27
- solve_internal_odes! = (resid_nodes, us, p, cur_nshoot, nodes) -> __multiple_shooting_solve_internal_odes! (resid_nodes,
28
- us, p, Val (iip), f, cur_nshoot, nodes, tspan, u0_size, N, alg, ensemblealg,
29
- internal_ode_kwargs)
27
+ function solve_internal_odes! (resid_nodes:: T1 , us:: T2 , p:: T3 , cur_nshoot:: Int ,
28
+ nodes:: T4 ) where {T1, T2, T3, T4}
29
+ return __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, prob, f,
30
+ cur_nshoot, nodes, tspan, u0_size, N, alg, ensemblealg, internal_ode_kwargs)
31
+ end
30
32
31
33
# This gets all the nshoots except the final SingleShooting case
32
34
all_nshoots = __get_all_nshoots (alg. grid_coarsening, nshoots)
@@ -42,13 +44,13 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
42
44
end
43
45
44
46
if prob. problem_type isa TwoPointBVProblem
45
- __solve_nlproblem! (alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N ,
46
- resida_len, residb_len, Val (iip), solve_internal_odes!, bc[1 ], bc[2 ], prob ,
47
- u0; verbose, kwargs... , nlsolve_kwargs... )
47
+ __solve_nlproblem! (prob . problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
48
+ cur_nshoot, N, resida_len, residb_len, solve_internal_odes!, bc[1 ], bc[2 ],
49
+ prob, u0; verbose, kwargs... , nlsolve_kwargs... )
48
50
else
49
- __solve_nlproblem! (alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N ,
50
- prod (resid_size), Val (iip), solve_internal_odes!, bc, prob, f, u0_size, u0;
51
- verbose, kwargs... , nlsolve_kwargs... )
51
+ __solve_nlproblem! (prob . problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
52
+ cur_nshoot, N, prod (resid_size), solve_internal_odes!, bc, prob, f,
53
+ u0_size, u0; verbose, kwargs... , nlsolve_kwargs... )
52
54
end
53
55
end
54
56
@@ -57,9 +59,9 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
57
59
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs... )
58
60
end
59
61
60
- function __solve_nlproblem! (alg:: MultipleShooting , bcresid_prototype, u_at_nodes, nodes ,
61
- cur_nshoot, N, resida_len, residb_len, iip :: Val , solve_internal_odes! :: S , bca :: B1 ,
62
- bcb:: B2 , prob, u0; kwargs... ) where {B1, B2, S}
62
+ function __solve_nlproblem! (:: TwoPointBVProblem , alg:: MultipleShooting , bcresid_prototype,
63
+ u_at_nodes, nodes, cur_nshoot :: Int , N :: Int , resida_len :: Int , residb_len :: Int ,
64
+ solve_internal_odes! :: S , bca :: B1 , bcb:: B2 , prob, u0; kwargs... ) where {B1, B2, S}
63
65
if __any_sparse_ad (alg. jac_alg)
64
66
J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
65
67
bcresid_prototype, u0, N, cur_nshoot)
@@ -69,7 +71,7 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
69
71
similar (u_at_nodes, cur_nshoot * N), bcresid_prototype[2 ])
70
72
71
73
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss! (du, u, p, cur_nshoot,
72
- nodes, iip , solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
74
+ nodes, prob , solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
73
75
loss_fnₚ = (du, u) -> loss_fn (du, u, prob. p)
74
76
75
77
sd_bvp = alg. jac_alg. diffmode isa AbstractSparseADType ?
@@ -93,8 +95,8 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
93
95
return nothing
94
96
end
95
97
96
- function __solve_nlproblem! (alg:: MultipleShooting , bcresid_prototype, u_at_nodes, nodes ,
97
- cur_nshoot, N, resid_size, iip :: Val , solve_internal_odes!:: S , bc:: BC ,
98
+ function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting , bcresid_prototype,
99
+ u_at_nodes, nodes, cur_nshoot, N, resid_len :: Int , solve_internal_odes!:: S , bc:: BC ,
98
100
prob, f:: F , u0_size, u0; kwargs... ) where {BC, F, S}
99
101
if __any_sparse_ad (alg. jac_alg)
100
102
J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
@@ -107,18 +109,17 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
107
109
pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. bc_diffmode)
108
110
109
111
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
110
- nodes, iip , solve_internal_odes!, prod (resid_size) , N, f, bc, u0_size,
111
- tspan, alg. ode_alg, u0)
112
+ nodes, prob , solve_internal_odes!, resid_len , N, f, bc, u0_size, prob . tspan ,
113
+ alg. ode_alg, u0)
112
114
113
115
ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes)
114
116
sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
115
117
__sparsity_detection_alg (J_proto) : NoSparsityDetection ()
116
118
ode_jac_cache = sparse_jacobian_cache (alg. jac_alg. nonbc_diffmode, sd_ode,
117
119
ode_fn, similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
118
120
119
- bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p,
120
- cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg. ode_alg,
121
- u0)
121
+ bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p, cur_nshoot, nodes,
122
+ prob, solve_internal_odes!, N, f, bc, u0_size, tspan, alg. ode_alg, u0)
122
123
sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
123
124
SymbolicsSparsityDetection () : NoSparsityDetection ()
124
125
bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
@@ -140,9 +141,10 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
140
141
return nothing
141
142
end
142
143
143
- function __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, :: Val{iip} , f:: F ,
144
+ function __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, _prob , f:: F ,
144
145
cur_nshoots:: Int , nodes, tspan, u0_size, N, alg:: MultipleShooting ,
145
- ensemblealg, kwargs) where {iip, F}
146
+ ensemblealg, kwargs) where {F}
147
+ iip = isinplace (_prob)
146
148
ts_ = Vector {Vector{typeof(first(tspan))}} (undef, cur_nshoots)
147
149
us_ = Vector {Vector{typeof(us)}} (undef, cur_nshoots)
148
150
@@ -171,14 +173,15 @@ function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, ::Val{iip}
171
173
return reduce (vcat, ensemble_sol. u. us), reduce (vcat, ensemble_sol. u. ts)
172
174
end
173
175
174
- function __multiple_shooting_2point_jacobian! (J, us, p, jac_cache, loss_fn, resid,
175
- alg:: MultipleShooting )
176
+ function __multiple_shooting_2point_jacobian! (J, us, p, jac_cache, loss_fn:: F , resid,
177
+ alg:: MultipleShooting ) where {F}
176
178
sparse_jacobian! (J, alg. jac_alg. diffmode, jac_cache, loss_fn, resid, us)
177
179
return nothing
178
180
end
179
181
180
182
function __multiple_shooting_mpoint_jacobian! (J, us, p, resid_bc, resid_nodes,
181
- ode_jac_cache, bc_jac_cache, ode_fn, bc_fn, alg:: MultipleShooting , N:: Int )
183
+ ode_jac_cache, bc_jac_cache, ode_fn:: F1 , bc_fn:: F2 , alg:: MultipleShooting ,
184
+ N:: Int ) where {F1, F2}
182
185
J_bc = @view (J[1 : N, :])
183
186
J_c = @view (J[(N + 1 ): end , :])
184
187
@@ -190,8 +193,8 @@ function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
190
193
end
191
194
192
195
@views function __multiple_shooting_2point_loss! (resid, us, p, cur_nshoots:: Int , nodes,
193
- :: Val{iip} , solve_internal_odes!:: S , resida_len, residb_len, N, bca,
194
- bcb) where {iip, S }
196
+ prob , solve_internal_odes!:: S , resida_len, residb_len, N, bca:: BCA ,
197
+ bcb:: BCB ) where {S, BCA, BCB }
195
198
resid_ = resid[(resida_len + 1 ): (end - residb_len)]
196
199
solve_internal_odes! (resid_, us, p, cur_nshoots, nodes)
197
200
201
204
ua = us[1 : N]
202
205
ub = us[(end - N + 1 ): end ]
203
206
204
- if iip
207
+ if isinplace (prob)
205
208
bca (resid_bc_a, ua, p)
206
209
bcb (resid_bc_b, ub, p)
207
210
else
213
216
end
214
217
215
218
@views function __multiple_shooting_mpoint_loss_bc! (resid_bc, us, p, cur_nshoots:: Int ,
216
- nodes, :: Val{iip} , solve_internal_odes!:: S , N, f, bc, u0_size, tspan,
217
- ode_alg, u0) where {iip, S}
219
+ nodes, prob, solve_internal_odes!:: S , N, f:: F , bc:: BC , u0_size, tspan,
220
+ ode_alg, u0) where {S, F, BC}
221
+ iip = isinplace (prob)
218
222
_resid_nodes = similar (us, cur_nshoots * N)
219
223
220
224
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
233
237
end
234
238
235
239
@views function __multiple_shooting_mpoint_loss! (resid, us, p, cur_nshoots:: Int , nodes,
236
- :: Val{iip} , solve_internal_odes!:: S , resid_len, N, f, bc, u0_size, tspan,
237
- ode_alg, u0) where {iip, S}
240
+ prob, solve_internal_odes!:: S , resid_len, N, f:: F , bc:: BC , u0_size, tspan,
241
+ ode_alg, u0) where {S, F, BC}
242
+ iip = isinplace (prob)
238
243
resid_bc = resid[1 : resid_len]
239
244
resid_nodes = resid[(resid_len + 1 ): end ]
240
245
0 commit comments