@@ -2,11 +2,78 @@ function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs
22 nlsolve_kwargs = (;), ensemblealg = EnsembleThreads (), verbose = true , kwargs... )
33 # For TwoPointBVPs there is nothing to do. Forward to general multiple shooting
44 prob. problem_type isa TwoPointBVProblem &&
5- return __solve_internal (prob, _alg; odesolve_kwargs, nlsolve_kwargs, ensemblealg,
6- verbose, kwargs... )
5+ return __solve_internal (prob, __without_static_nodes (_alg); odesolve_kwargs,
6+ nlsolve_kwargs, ensemblealg, verbose, kwargs... )
7+
8+ ig, T, N, Nig, u0 = __extract_problem_details (prob; dt = 0.1 )
9+
10+ if _unwrap_val (ig) && prob. u0 isa AbstractVector
11+ if verbose
12+ @warn " Static Nodes for Multiple-Shooting is not supported when Vector of \
13+ initial guesses are provided. Falling back to using the generic method!"
14+ end
15+ return __solve_internal (prob, __without_static_nodes (_alg); odesolve_kwargs,
16+ nlsolve_kwargs, ensemblealg, verbose, kwargs... )
17+ end
18+
19+ has_initial_guess = _unwrap_val (ig)
20+
21+ bcresid_prototype, resid_size = __get_bcresid_prototype (prob, u0)
22+ iip, bc, u0, u0_size = isinplace (prob), prob. f. bc, deepcopy (u0), size (u0)
723
824 # Extract the time-points used in BC
9- _prob = ODEProblem {isinplace(prob)} (prob. f, prob. u0, prob. tspan, prob. p)
25+ _prob = ODEProblem {iip} (prob. f, prob. u0, prob. tspan, prob. p)
26+ _fake_ode_sol = __construct_fake_ode_solution (_prob, _alg. ode_alg)
27+ if iip
28+ bc (bcresid_prototype, _fake_ode_sol, prob. p, _fake_ode_sol. sol. t)
29+ else
30+ bc (_fake_ode_sol, prob. p, _fake_ode_sol. sol. t)
31+ end
32+ __finalize_nodes! (_fake_ode_sol)
33+
34+ __alg = concretize_jacobian_algorithm (_alg, prob)
35+ alg = if has_initial_guess && Nig != __alg. nshoots
36+ verbose &&
37+ @warn " Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig) `"
38+ update_nshoots (__alg, Nig)
39+ else
40+ __alg
41+ end
42+ nshoots = alg. nshoots
43+ M = length (bcresid_prototype)
44+
45+ internal_ode_kwargs = (; verbose, kwargs... , odesolve_kwargs... , save_end = true )
46+
47+ function solve_internal_odes! (resid_nodes:: T1 , us:: T2 , p:: T3 , cur_nshoot:: Int ,
48+ nodes:: T4 , odecache:: C ) where {T1, T2, T3, T4, C}
49+ return __multiple_shooting_solve_internal_odes! (resid_nodes, us, cur_nshoot,
50+ odecache, nodes, u0_size, N, ensemblealg)
51+ end
52+
53+ ode_cache_loss_fn = __multiple_shooting_init_odecache (ensemblealg, prob,
54+ alg. ode_alg, u0, nshoots; internal_ode_kwargs... )
55+
56+ nodes = typeof (first (tspan))[]
57+ u_at_nodes = __multiple_shooting_initialize! (nodes, prob, alg, ig, nshoots,
58+ ode_cache_loss_fn; kwargs... , verbose, odesolve_kwargs... ,
59+ static_nodes = _fake_ode_sol. nodes)
60+
61+ __solve_nlproblem! (prob. problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
62+ nshoots, M, N, prod (resid_size), solve_internal_odes!, bc, prob, prob. f,
63+ u0_size, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; verbose,
64+ kwargs... , nlsolve_kwargs... )
65+
66+ if prob. problem_type isa TwoPointBVProblem
67+ diffmode_shooting = __get_non_sparse_ad (alg. jac_alg. diffmode)
68+ else
69+ diffmode_shooting = __get_non_sparse_ad (alg. jac_alg. bc_diffmode)
70+ end
71+ shooting_alg = Shooting (alg. ode_alg, alg. nlsolve,
72+ BVPJacobianAlgorithm (diffmode_shooting))
73+
74+ single_shooting_prob = remake (prob; u0 = reshape (@view (u_at_nodes[1 : N]), u0_size))
75+ return __solve (single_shooting_prob, shooting_alg; odesolve_kwargs, nlsolve_kwargs,
76+ verbose, kwargs... )
1077end
1178
1279function __solve (prob:: BVProblem , _alg:: MultipleShooting{false} ; kwargs... )
@@ -145,10 +212,71 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
145212 return nothing
146213end
147214
148- function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting , bcresid_prototype,
149- u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int , resid_len:: Int ,
150- solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0, ode_cache_loss_fn,
151- ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
215+ function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting{true} ,
216+ bcresid_prototype, u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int ,
217+ resid_len:: Int , solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0,
218+ ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
219+ if __any_sparse_ad (alg. jac_alg)
220+ J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
221+ bcresid_prototype, u0, N, cur_nshoot)
222+ end
223+ resid_prototype = vcat (bcresid_prototype, similar (u_at_nodes, cur_nshoot * N))
224+
225+ __resid_nodes = resid_prototype[(end - cur_nshoot * N + 1 ): end ]
226+ resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
227+ pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. bc_diffmode)
228+
229+ loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
230+ nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob. tspan,
231+ alg. ode_alg, u0, ode_cache_loss_fn)
232+
233+ # ODE Part
234+ sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
235+ __sparsity_detection_alg (J_proto) : NoSparsityDetection ()
236+ ode_jac_cache = sparse_jacobian_cache (alg. jac_alg. nonbc_diffmode, sd_ode,
237+ nothing , similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
238+ ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
239+ ode_jac_cache, alg. jac_alg. nonbc_diffmode, alg. ode_alg, cur_nshoot, u0;
240+ internal_ode_kwargs... )
241+
242+ # BC Part
243+ sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
244+ SymbolicsSparsityDetection () : NoSparsityDetection ()
245+ bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
246+ sd_bc, nothing , similar (bcresid_prototype), u_at_nodes)
247+ ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
248+ bc_jac_cache, alg. jac_alg. bc_diffmode, alg. ode_alg, cur_nshoot, u0;
249+ internal_ode_kwargs... )
250+
251+ jac_prototype = vcat (init_jacobian (bc_jac_cache), init_jacobian (ode_jac_cache))
252+
253+ # Define the functions now
254+ ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes,
255+ ode_cache_ode_jac_fn)
256+ bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc_static_node! (du, u, prob. p,
257+ cur_nshoot, nodes,
258+ prob, solve_internal_odes!, N, f, bc, u0_size, prob. tspan, alg. ode_alg, u0,
259+ ode_cache_bc_jac_fn)
260+
261+ jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian! (J, u, p,
262+ similar (bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
263+ ode_fn, bc_fn, alg, N, M)
264+
265+ loss_function! = NonlinearFunction {true} (loss_fn; resid_prototype, jac = jac_fn,
266+ jac_prototype)
267+
268+ # NOTE: u_at_nodes is updated inplace
269+ nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
270+ u_at_nodes, prob. p)
271+ __solve (nlprob, alg. nlsolve; kwargs... , alias_u0 = true )
272+
273+ return nothing
274+ end
275+
276+ function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting{false} ,
277+ bcresid_prototype, u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int ,
278+ resid_len:: Int , solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0,
279+ ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
152280 if __any_sparse_ad (alg. jac_alg)
153281 J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
154282 bcresid_prototype, u0, N, cur_nshoot)
348476 return nothing
349477end
350478
479+ @views function __multiple_shooting_mpoint_loss_bc_static_node! (resid_bc, us, p,
480+ cur_nshoots:: Int , nodes, prob, solve_internal_odes!:: S , N, f:: F , bc:: BC , u0_size,
481+ tspan, ode_alg, u0, ode_cache) where {S, F, BC}
482+ iip = isinplace (prob)
483+
484+ # NOTE: We placed the nodes at the points `bc` is evaluated so we don't need to
485+ # recompute the solution
486+ _ts = nodes
487+ _us = [reshape (us[((i - 1 ) * prod (u0_size) + 1 ): (i * prod (u0_size))], u0_size)
488+ for i in eachindex (_ts)]
489+
490+ odeprob = ODEProblem {iip} (f, u0, tspan, p)
491+ total_solution = SciMLBase. build_solution (odeprob, ode_alg, _ts, _us)
492+
493+ if iip
494+ eval_bc_residual! (resid_bc, StandardBVProblem (), bc, total_solution, p)
495+ else
496+ resid_bc .= eval_bc_residual (StandardBVProblem (), bc, total_solution, p)
497+ end
498+
499+ return nothing
500+ end
501+
351502@views function __multiple_shooting_mpoint_loss! (resid, us, p, cur_nshoots:: Int , nodes,
352503 prob, solve_internal_odes!:: S , resid_len, N, f:: F , bc:: BC , u0_size, tspan,
353504 ode_alg, u0, ode_cache) where {S, F, BC}
@@ -390,12 +541,22 @@ end
390541
391542# No initial guess
392543@views function __multiple_shooting_initialize! (nodes, prob, alg:: MultipleShooting ,
393- :: Val{false} , nshoots:: Int , odecache_; verbose, kwargs... )
544+ :: Val{false} , nshoots:: Int , odecache_; verbose, static_nodes = nothing , kwargs... )
394545 @unpack f, u0, tspan, p = prob
395546 @unpack ode_alg = alg
396547
397548 resize! (nodes, nshoots + 1 )
398549 nodes .= range (tspan[1 ], tspan[2 ]; length = nshoots + 1 )
550+
551+ if static_nodes != = nothing
552+ idx = 1
553+ for snode in static_nodes
554+ sidx = searchsortedfirst (nodes[idx: end ], snode)
555+ nodes[idx + sidx - 1 ] = snode
556+ idx = sidx + 1
557+ end
558+ end
559+
399560 N = length (u0)
400561
401562 # Ensures type stability in case the parameters are dual numbers
0 commit comments