@@ -68,13 +68,8 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
68
68
resid_prototype = vcat (bcresid_prototype[1 ],
69
69
similar (u_at_nodes, cur_nshoot * N), bcresid_prototype[2 ])
70
70
71
- __resid_nodes = resid_prototype[(resida_len + 1 ): (resida_len + cur_nshoot * N)]
72
- resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
73
- pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. diffmode)
74
-
75
71
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss! (du, u, p, cur_nshoot,
76
- nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca,
77
- bcb)
72
+ nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
78
73
loss_fnₚ = (du, u) -> loss_fn (du, u, prob. p)
79
74
80
75
sd_bvp = alg. jac_alg. diffmode isa AbstractSparseADType ?
@@ -113,7 +108,7 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
113
108
114
109
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
115
110
nodes, iip, solve_internal_odes!, prod (resid_size), N, f, bc, u0_size,
116
- tspan, alg. ode_alg)
111
+ tspan, alg. ode_alg, u0 )
117
112
118
113
ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes)
119
114
sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
@@ -122,7 +117,8 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
122
117
ode_fn, similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
123
118
124
119
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p,
125
- cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg. ode_alg)
120
+ cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg. ode_alg,
121
+ u0)
126
122
sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
127
123
SymbolicsSparsityDetection () : NoSparsityDetection ()
128
124
bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
@@ -144,9 +140,9 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
144
140
return nothing
145
141
end
146
142
147
- function __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, :: Val{iip} , f,
143
+ function __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, :: Val{iip} , f:: F ,
148
144
cur_nshoots:: Int , nodes, tspan, u0_size, N, alg:: MultipleShooting ,
149
- ensemblealg, kwargs) where {iip}
145
+ ensemblealg, kwargs) where {iip, F }
150
146
ts_ = Vector {Vector{typeof(first(tspan))}} (undef, cur_nshoots)
151
147
us_ = Vector {Vector{typeof(us)}} (undef, cur_nshoots)
152
148
@@ -194,7 +190,7 @@ function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
194
190
end
195
191
196
192
@views function __multiple_shooting_2point_loss! (resid, us, p, cur_nshoots:: Int , nodes,
197
- :: Val{iip} , solve_internal_odes!, resida_len, residb_len, N, bca, bcb) where {iip}
193
+ :: Val{iip} , solve_internal_odes!:: S , resida_len, residb_len, N, bca, bcb) where {iip, S }
198
194
resid_ = resid[(resida_len + 1 ): (end - residb_len)]
199
195
solve_internal_odes! (resid_, us, p, cur_nshoots, nodes)
200
196
@@ -216,13 +212,14 @@ end
216
212
end
217
213
218
214
@views function __multiple_shooting_mpoint_loss_bc! (resid_bc, us, p, cur_nshoots:: Int ,
219
- nodes, :: Val{iip} , solve_internal_odes!, N, f, bc, u0_size, tspan, ode_alg) where {iip}
215
+ nodes, :: Val{iip} , solve_internal_odes!:: S , N, f, bc, u0_size, tspan,
216
+ ode_alg, u0) where {iip, S}
220
217
_resid_nodes = similar (us, cur_nshoots * N)
221
218
222
219
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
223
220
_us, _ts = solve_internal_odes! (_resid_nodes, us, p, cur_nshoots, nodes)
224
221
225
- odeprob = ODEProblem {iip} (f, reshape (us[ 1 : N], u0_size) , tspan, p)
222
+ odeprob = ODEProblem {iip} (f, u0 , tspan, p)
226
223
total_solution = SciMLBase. build_solution (odeprob, ode_alg, _ts, _us)
227
224
228
225
if iip
@@ -235,14 +232,14 @@ end
235
232
end
236
233
237
234
@views function __multiple_shooting_mpoint_loss! (resid, us, p, cur_nshoots:: Int , nodes,
238
- :: Val{iip} , solve_internal_odes!, resid_len, N, f, bc, u0_size, tspan,
239
- ode_alg) where {iip}
235
+ :: Val{iip} , solve_internal_odes!:: S , resid_len, N, f, bc, u0_size, tspan,
236
+ ode_alg, u0 ) where {iip, S }
240
237
resid_bc = resid[1 : resid_len]
241
238
resid_nodes = resid[(resid_len + 1 ): end ]
242
239
243
240
_us, _ts = solve_internal_odes! (resid_nodes, us, p, cur_nshoots, nodes)
244
241
245
- odeprob = ODEProblem {iip} (f, reshape (us[ 1 : N], u0_size) , tspan, p)
242
+ odeprob = ODEProblem {iip} (f, u0 , tspan, p)
246
243
total_solution = SciMLBase. build_solution (odeprob, ode_alg, _ts, _us)
247
244
248
245
if iip
0 commit comments