@@ -2,6 +2,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
2
2
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads (), verbose = true , kwargs... )
3
3
@unpack f, tspan = prob
4
4
5
+ @assert (ensemblealg isa EnsembleSerial)|| (ensemblealg isa EnsembleThreads) " Currently MultipleShooting only supports `EnsembleSerial` and `EnsembleThreads`!"
6
+
5
7
ig, T, N, Nig, u0 = __extract_problem_details (prob; dt = 0.1 )
6
8
has_initial_guess = _unwrap_val (ig)
7
9
@@ -27,33 +29,40 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
27
29
end
28
30
29
31
internal_ode_kwargs = (; verbose, kwargs... , odesolve_kwargs... , save_end = true )
32
+
30
33
function solve_internal_odes! (resid_nodes:: T1 , us:: T2 , p:: T3 , cur_nshoot:: Int ,
31
- nodes:: T4 ) where {T1, T2, T3, T4}
32
- return __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, prob, f ,
33
- cur_nshoot , nodes, tspan, u0_size, N, alg, ensemblealg, internal_ode_kwargs )
34
+ nodes:: T4 , odecache :: C ) where {T1, T2, T3, T4, C }
35
+ return __multiple_shooting_solve_internal_odes! (resid_nodes, us, cur_nshoot ,
36
+ odecache , nodes, u0_size, N, ensemblealg)
34
37
end
35
38
36
39
# This gets all the nshoots except the final SingleShooting case
37
40
all_nshoots = __get_all_nshoots (alg. grid_coarsening, nshoots)
38
41
u_at_nodes, nodes = similar (u0, 0 ), typeof (first (tspan))[]
39
42
43
+ ode_cache_loss_fn = __multiple_shooting_init_odecache (ensemblealg, prob,
44
+ alg. ode_alg, u0, maximum (all_nshoots); internal_ode_kwargs... )
45
+
40
46
for (i, cur_nshoot) in enumerate (all_nshoots)
41
47
if i == 1
42
- u_at_nodes = __multiple_shooting_initialize! (nodes, prob, alg, ig, nshoots;
43
- kwargs... , verbose, odesolve_kwargs... )
48
+ u_at_nodes = __multiple_shooting_initialize! (nodes, prob, alg, ig, nshoots,
49
+ ode_cache_loss_fn; kwargs... , verbose, odesolve_kwargs... )
44
50
else
45
51
u_at_nodes = __multiple_shooting_initialize! (nodes, u_at_nodes, prob, alg,
46
- cur_nshoot, all_nshoots[i - 1 ], ig; kwargs... , verbose, odesolve_kwargs... )
52
+ cur_nshoot, all_nshoots[i - 1 ], ig, ode_cache_loss_fn; kwargs... , verbose,
53
+ odesolve_kwargs... )
47
54
end
48
55
49
56
if prob. problem_type isa TwoPointBVProblem
50
57
__solve_nlproblem! (prob. problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
51
- cur_nshoot, N, resida_len, residb_len, solve_internal_odes!, bc[1 ], bc[2 ],
52
- prob, u0, M; verbose, kwargs... , nlsolve_kwargs... )
58
+ cur_nshoot, M, N, resida_len, residb_len, solve_internal_odes!, bc[1 ],
59
+ bc[2 ], prob, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs;
60
+ verbose, kwargs... , nlsolve_kwargs... )
53
61
else
54
62
__solve_nlproblem! (prob. problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
55
- cur_nshoot, N, prod (resid_size), solve_internal_odes!, bc, prob, f,
56
- u0_size, u0, M; verbose, kwargs... , nlsolve_kwargs... )
63
+ cur_nshoot, M, N, prod (resid_size), solve_internal_odes!, bc, prob, f,
64
+ u0_size, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; verbose,
65
+ kwargs... , nlsolve_kwargs... )
57
66
end
58
67
end
59
68
@@ -62,9 +71,15 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
62
71
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs... )
63
72
end
64
73
74
+ # TODO : We can save even more memory by hoisting the preallocated caches for the ODEs
75
+ # TODO : out of the `__solve_nlproblem!` function and into the `__solve` function.
76
+ # TODO : But we can do it another day. Currently the gains here are quite high to justify
77
+ # TODO : waiting.
78
+
65
79
function __solve_nlproblem! (:: TwoPointBVProblem , alg:: MultipleShooting , bcresid_prototype,
66
- u_at_nodes, nodes, cur_nshoot:: Int , N:: Int , resida_len:: Int , residb_len:: Int ,
67
- solve_internal_odes!:: S , bca:: B1 , bcb:: B2 , prob, u0, M; kwargs... ) where {B1, B2, S}
80
+ u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int , resida_len:: Int ,
81
+ residb_len:: Int , solve_internal_odes!:: S , bca:: B1 , bcb:: B2 , prob, u0,
82
+ ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs... ) where {B1, B2, S}
68
83
if __any_sparse_ad (alg. jac_alg)
69
84
J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
70
85
bcresid_prototype, u0, N, cur_nshoot)
@@ -74,17 +89,25 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
74
89
similar (u_at_nodes, cur_nshoot * N), bcresid_prototype[2 ])
75
90
76
91
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss! (du, u, p, cur_nshoot,
77
- nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
78
- loss_fnₚ = (du, u) -> loss_fn (du, u, prob . p )
92
+ nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
93
+ ode_cache_loss_fn )
79
94
80
95
sd_bvp = alg. jac_alg. diffmode isa AbstractSparseADType ?
81
96
__sparsity_detection_alg (J_proto) : NoSparsityDetection ()
82
97
83
98
resid_prototype_cached = similar (resid_prototype)
84
- jac_cache = sparse_jacobian_cache (alg. jac_alg. diffmode, sd_bvp, loss_fnₚ ,
99
+ jac_cache = sparse_jacobian_cache (alg. jac_alg. diffmode, sd_bvp, nothing ,
85
100
resid_prototype_cached, u_at_nodes)
86
101
jac_prototype = init_jacobian (jac_cache)
87
102
103
+ ode_cache_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
104
+ jac_cache, alg. jac_alg. diffmode, alg. ode_alg, cur_nshoot, u0;
105
+ internal_ode_kwargs... )
106
+
107
+ loss_fnₚ = (du, u) -> __multiple_shooting_2point_loss! (du, u, prob. p, cur_nshoot,
108
+ nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb,
109
+ ode_cache_jac_fn)
110
+
88
111
jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian! (J, u, p, jac_cache,
89
112
loss_fnₚ, resid_prototype_cached, alg)
90
113
@@ -100,8 +123,9 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
100
123
end
101
124
102
125
function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting , bcresid_prototype,
103
- u_at_nodes, nodes, cur_nshoot, N, resid_len:: Int , solve_internal_odes!:: S , bc:: BC ,
104
- prob, f:: F , u0_size, u0, M; kwargs... ) where {BC, F, S}
126
+ u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int , resid_len:: Int ,
127
+ solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0, ode_cache_loss_fn,
128
+ ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
105
129
if __any_sparse_ad (alg. jac_alg)
106
130
J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
107
131
bcresid_prototype, u0, N, cur_nshoot)
@@ -114,23 +138,37 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
114
138
115
139
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
116
140
nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob. tspan,
117
- alg. ode_alg, u0)
141
+ alg. ode_alg, u0, ode_cache_loss_fn )
118
142
119
- ode_fn = (du, u) -> solve_internal_odes! (du, u, prob . p, cur_nshoot, nodes)
143
+ # ODE Part
120
144
sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
121
145
__sparsity_detection_alg (J_proto) : NoSparsityDetection ()
122
146
ode_jac_cache = sparse_jacobian_cache (alg. jac_alg. nonbc_diffmode, sd_ode,
123
- ode_fn, similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
124
-
125
- bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p, cur_nshoot, nodes,
126
- prob, solve_internal_odes!, N, f, bc, u0_size, tspan, alg. ode_alg, u0)
127
- sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
128
- SymbolicsSparsityDetection () : NoSparsityDetection ()
147
+ nothing , similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
148
+ ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
149
+ ode_jac_cache, alg. jac_alg. nonbc_diffmode, alg. ode_alg, cur_nshoot, u0;
150
+ internal_ode_kwargs... )
151
+
152
+ # BC Part
153
+ if alg. jac_alg. bc_diffmode isa AbstractSparseADType
154
+ error (" Multiple Shooting doesn't support sparse AD for Boundary Conditions yet!" )
155
+ end
156
+ sd_bc = NoSparsityDetection ()
129
157
bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
130
- sd_bc, bc_fn, similar (bcresid_prototype), u_at_nodes)
158
+ sd_bc, nothing , similar (bcresid_prototype), u_at_nodes)
159
+ ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
160
+ bc_jac_cache, alg. jac_alg. bc_diffmode, alg. ode_alg, cur_nshoot, u0;
161
+ internal_ode_kwargs... )
131
162
132
163
jac_prototype = vcat (init_jacobian (bc_jac_cache), init_jacobian (ode_jac_cache))
133
164
165
+ # Define the functions now
166
+ ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes,
167
+ ode_cache_ode_jac_fn)
168
+ bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p, cur_nshoot, nodes,
169
+ prob, solve_internal_odes!, N, f, bc, u0_size, prob. tspan, alg. ode_alg, u0,
170
+ ode_cache_bc_jac_fn)
171
+
134
172
jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian! (J, u, p,
135
173
similar (bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
136
174
ode_fn, bc_fn, alg, N, M)
@@ -146,36 +184,85 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
146
184
return nothing
147
185
end
148
186
149
- function __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, _prob, f:: F ,
150
- cur_nshoots:: Int , nodes, tspan, u0_size, N, alg:: MultipleShooting ,
151
- ensemblealg, kwargs) where {F}
152
- iip = isinplace (_prob)
187
+ function __multiple_shooting_init_odecache (:: EnsembleSerial , prob, alg, u0, nshoots;
188
+ kwargs... )
189
+ odeprob = ODEProblem {isinplace(prob)} (prob. f, u0, prob. tspan, prob. p)
190
+ return SciMLBase. __init (odeprob, alg; kwargs... )
191
+ end
192
+
193
+ function __multiple_shooting_init_odecache (:: EnsembleThreads , prob, alg, u0, nshoots;
194
+ kwargs... )
195
+ odeprob = ODEProblem {isinplace(prob)} (prob. f, u0, prob. tspan, prob. p)
196
+ return [SciMLBase. __init (odeprob, alg; kwargs... )
197
+ for _ in 1 : min (Threads. nthreads (), nshoots)]
198
+ end
199
+
200
+ function __multiple_shooting_init_jacobian_odecache (ensemblealg, prob, jac_cache, ad, alg,
201
+ nshoots, u; kwargs... )
202
+ return __multiple_shooting_init_odecache (ensemblealg, prob, alg, u, nshoots;
203
+ kwargs... )
204
+ end
205
+
206
+ function __multiple_shooting_init_jacobian_odecache (ensemblealg, prob, jac_cache,
207
+ :: Union{AutoForwardDiff, AutoSparseForwardDiff} , alg, nshoots, u;
208
+ kwargs... )
209
+ cache = jac_cache. cache
210
+ if cache isa ForwardDiff. JacobianConfig
211
+ xduals = reshape (cache. duals[2 ][1 : length (u)], size (u))
212
+ else
213
+ xduals = reshape (cache. t[1 : length (u)], size (u))
214
+ end
215
+ fill! (xduals, 0 )
216
+ return __multiple_shooting_init_odecache (ensemblealg, prob, alg, xduals, nshoots;
217
+ kwargs... )
218
+ end
219
+
220
+ # Not using `EnsembleProblem` since it is hard to initialize the cache and stuff
221
+ function __multiple_shooting_solve_internal_odes! (resid_nodes, us, cur_nshoots:: Int ,
222
+ odecache, nodes, u0_size, N:: Int , :: EnsembleSerial )
153
223
ts_ = Vector {Vector{typeof(first(tspan))}} (undef, cur_nshoots)
154
224
us_ = Vector {Vector{typeof(us)}} (undef, cur_nshoots)
155
225
156
- function prob_func (probᵢ, i, _)
157
- return remake (probᵢ; u0 = reshape (@view (us[((i - 1 ) * N + 1 ): (i * N)]), u0_size),
158
- tspan = (nodes[i], nodes[i + 1 ]))
226
+ for i in 1 : cur_nshoots
227
+ SciMLBase. reinit! (odecache, reshape (@view (us[((i - 1 ) * N + 1 ): (i * N)]), u0_size);
228
+ t0 = nodes[i], tf = nodes[i + 1 ])
229
+ sol = solve! (odecache)
230
+ us_[i] = deepcopy (sol. u)
231
+ ts_[i] = deepcopy (sol. t)
232
+ resid_nodes[((i - 1 ) * N + 1 ): (i * N)] .= @view (us[(i * N + 1 ): ((i + 1 ) * N)]) .-
233
+ vec (sol. u[end ])
159
234
end
160
235
161
- function reduction (u, data, I)
162
- for i in I
163
- u. us[i] = data[i]. u
164
- u. ts[i] = data[i]. t
165
- u. resid[((i - 1 ) * N + 1 ): (i * N)] .= vec (@view (us[(i * N + 1 ): ((i + 1 ) * N)])) .-
166
- vec (data[i]. u[end ])
167
- end
168
- return (u, false )
169
- end
236
+ return reduce (vcat, us_), reduce (vcat, ts_)
237
+ end
170
238
171
- odeprob = ODEProblem {iip} (f, reshape (@view (us[1 : N]), u0_size), tspan, p)
239
+ function __multiple_shooting_solve_internal_odes! (resid_nodes, us, cur_nshoots:: Int ,
240
+ odecache:: Vector , nodes, u0_size, N:: Int , :: EnsembleThreads )
241
+ ts_ = Vector {Vector{typeof(first(tspan))}} (undef, cur_nshoots)
242
+ us_ = Vector {Vector{typeof(us)}} (undef, cur_nshoots)
172
243
173
- ensemble_prob = EnsembleProblem (odeprob; prob_func, reduction, safetycopy = false ,
174
- u_init = (; us = us_, ts = ts_, resid = resid_nodes))
175
- ensemble_sol = __solve (ensemble_prob, alg. ode_alg, ensemblealg; kwargs... ,
176
- trajectories = cur_nshoots)
244
+ n_splits = min (cur_nshoots, Threads. nthreads ())
245
+ n_per_chunk, n_remaining = divrem (cur_nshoots, n_splits)
246
+ data_partition = map (1 : n_splits) do i
247
+ first = 1 + (i - 1 ) * n_per_chunk + ifelse (i ≤ n_remaining, i - 1 , n_remaining)
248
+ last = (first - 1 ) + n_per_chunk + ifelse (i <= n_remaining, 1 , 0 )
249
+ return first: 1 : last
250
+ end
177
251
178
- return reduce (vcat, ensemble_sol. u. us), reduce (vcat, ensemble_sol. u. ts)
252
+ Threads. @threads for idx in 1 : length (data_partition)
253
+ cache = odecache[idx]
254
+ for i in data_partition[idx]
255
+ SciMLBase. reinit! (cache, reshape (@view (us[((i - 1 ) * N + 1 ): (i * N)]), u0_size);
256
+ t0 = nodes[i], tf = nodes[i + 1 ])
257
+ sol = solve! (cache)
258
+ us_[i] = deepcopy (sol. u)
259
+ ts_[i] = deepcopy (sol. t)
260
+ resid_nodes[((i - 1 ) * N + 1 ): (i * N)] .= @view (us[(i * N + 1 ): ((i + 1 ) * N)]) .-
261
+ vec (sol. u[end ])
262
+ end
263
+ end
264
+
265
+ return reduce (vcat, us_), reduce (vcat, ts_)
179
266
end
180
267
181
268
function __multiple_shooting_2point_jacobian! (J, us, p, jac_cache, loss_fn:: F , resid,
@@ -198,10 +285,10 @@ function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
198
285
end
199
286
200
287
@views function __multiple_shooting_2point_loss! (resid, us, p, cur_nshoots:: Int , nodes,
201
- prob, solve_internal_odes!:: S , resida_len, residb_len, N, bca:: BCA ,
202
- bcb :: BCB ) where {S, BCA, BCB}
288
+ prob, solve_internal_odes!:: S , resida_len, residb_len, N, bca:: BCA , bcb :: BCB ,
289
+ ode_cache ) where {S, BCA, BCB}
203
290
resid_ = resid[(resida_len + 1 ): (end - residb_len)]
204
- solve_internal_odes! (resid_, us, p, cur_nshoots, nodes)
291
+ solve_internal_odes! (resid_, us, p, cur_nshoots, nodes, ode_cache )
205
292
206
293
resid_bc_a = resid[1 : resida_len]
207
294
resid_bc_b = resid[(end - residb_len + 1 ): end ]
@@ -222,12 +309,12 @@ end
222
309
223
310
@views function __multiple_shooting_mpoint_loss_bc! (resid_bc, us, p, cur_nshoots:: Int ,
224
311
nodes, prob, solve_internal_odes!:: S , N, f:: F , bc:: BC , u0_size, tspan,
225
- ode_alg, u0) where {S, F, BC}
312
+ ode_alg, u0, ode_cache ) where {S, F, BC}
226
313
iip = isinplace (prob)
227
314
_resid_nodes = similar (us, cur_nshoots * N)
228
315
229
316
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
230
- _us, _ts = solve_internal_odes! (_resid_nodes, us, p, cur_nshoots, nodes)
317
+ _us, _ts = solve_internal_odes! (_resid_nodes, us, p, cur_nshoots, nodes, ode_cache )
231
318
232
319
odeprob = ODEProblem {iip} (f, u0, tspan, p)
233
320
total_solution = SciMLBase. build_solution (odeprob, ode_alg, _ts, _us)
@@ -243,12 +330,12 @@ end
243
330
244
331
@views function __multiple_shooting_mpoint_loss! (resid, us, p, cur_nshoots:: Int , nodes,
245
332
prob, solve_internal_odes!:: S , resid_len, N, f:: F , bc:: BC , u0_size, tspan,
246
- ode_alg, u0) where {S, F, BC}
333
+ ode_alg, u0, ode_cache ) where {S, F, BC}
247
334
iip = isinplace (prob)
248
335
resid_bc = resid[1 : resid_len]
249
336
resid_nodes = resid[(resid_len + 1 ): end ]
250
337
251
- _us, _ts = solve_internal_odes! (resid_nodes, us, p, cur_nshoots, nodes)
338
+ _us, _ts = solve_internal_odes! (resid_nodes, us, p, cur_nshoots, nodes, ode_cache )
252
339
253
340
odeprob = ODEProblem {iip} (f, u0, tspan, p)
254
341
total_solution = SciMLBase. build_solution (odeprob, ode_alg, _ts, _us)
263
350
end
264
351
265
352
# Problem has initial guess
266
- @views function __multiple_shooting_initialize! (nodes, prob, alg, :: Val{true} , nshoots;
267
- kwargs... )
353
+ @views function __multiple_shooting_initialize! (nodes, prob, alg, :: Val{true} , nshoots:: Int ,
354
+ odecache; kwargs... )
268
355
@unpack u0, tspan = prob
269
356
270
357
resize! (nodes, nshoots + 1 )
279
366
280
367
# No initial guess
281
368
@views function __multiple_shooting_initialize! (nodes, prob, alg:: MultipleShooting ,
282
- :: Val{false} , nshoots; verbose, kwargs... )
369
+ :: Val{false} , nshoots:: Int , odecache_ ; verbose, kwargs... )
283
370
@unpack f, u0, tspan, p = prob
284
371
@unpack ode_alg = alg
285
372
298
385
end
299
386
300
387
# Assumes no initial guess for now
301
- start_prob = ODEProblem {isinplace(prob)} (f, u0, tspan, p)
302
- sol = __solve (start_prob, ode_alg; verbose, kwargs... , saveat = nodes)
388
+ odecache = odecache_ isa Vector ? first (odecache_) : odecache_
389
+ SciMLBase. reinit! (odecache, u0; t0 = tspan[1 ], tf = tspan[2 ])
390
+ sol = solve! (odecache)
303
391
304
392
if SciMLBase. successful_retcode (sol)
305
393
u_at_nodes[1 : N] .= vec (sol. u[1 ])
317
405
318
406
# Grid coarsening
319
407
@views function __multiple_shooting_initialize! (nodes, u_at_nodes_prev, prob, alg,
320
- nshoots, old_nshoots, ig; kwargs... )
408
+ nshoots, old_nshoots, ig, odecache_ ; kwargs... )
321
409
@unpack f, u0, tspan, p = prob
322
410
prev_nodes = copy (nodes)
411
+ odecache = odecache_ isa Vector ? first (odecache_) : odecache_
323
412
324
413
resize! (nodes, nshoots + 1 )
325
414
nodes .= range (tspan[1 ], tspan[2 ]; length = nshoots + 1 )
339
428
idxs_prev = (N + (ind - 2 ) * N .+ (1 : N))
340
429
u_at_nodes[idxs] .= u_at_nodes_prev[idxs_prev]
341
430
else
431
+ # TODO : Batch this computation and do it for all points between two nodes
432
+ # TODO : Though it is unlikely that this will be a bottleneck
342
433
# If the current node is not a node of the finer grid simulate from closest
343
434
# previous node and take result from simulation
344
435
fpos = floor (Int, pos)
351
442
idxs_prev = (N + (fpos - 2 ) * N .+ (1 : N))
352
443
ustart = u_at_nodes_prev[idxs_prev]
353
444
354
- # https://github.com/SciML/DifferentialEquations.jl/issues/975
355
- # odeprob = ODEProblem(f, ustart, (t0, tstop), p)
356
- odeprob = ODEProblem (f, copy (ustart), (t0, tstop), p)
357
- odesol = __solve (odeprob, alg. ode_alg; kwargs... , saveat = (), save_end = true )
445
+ SciMLBase. reinit! (odecache, ustart; t0, tf = tstop)
446
+ odesol = solve! (odecache)
358
447
359
448
u_at_nodes[idxs] .= odesol. u[end ]
360
449
end
0 commit comments