@@ -41,76 +41,107 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
41
41
cur_nshoot, all_nshoots[i - 1 ], ig; kwargs... , verbose, odesolve_kwargs... )
42
42
end
43
43
44
- if __any_sparse_ad (alg. jac_alg)
45
- J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
46
- bcresid_prototype, u0, N, cur_nshoot)
44
+ if prob. problem_type isa TwoPointBVProblem
45
+ __solve_nlproblem! (alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N,
46
+ resida_len, residb_len, Val (iip), solve_internal_odes!, bc[1 ], bc[2 ], prob,
47
+ u0; verbose, kwargs... , nlsolve_kwargs... )
48
+ else
49
+ __solve_nlproblem! (alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N,
50
+ prod (resid_size), Val (iip), solve_internal_odes!, bc, prob, f, u0_size, u0;
51
+ verbose, kwargs... , nlsolve_kwargs... )
47
52
end
53
+ end
48
54
49
- if prob. problem_type isa TwoPointBVProblem
50
- resid_prototype = vcat (bcresid_prototype[1 ],
51
- similar (u_at_nodes, cur_nshoot * N), bcresid_prototype[2 ])
55
+ single_shooting_prob = remake (prob; u0 = reshape (u_at_nodes[1 : N], u0_size))
56
+ return __solve (single_shooting_prob, Shooting (alg. ode_alg; alg. nlsolve);
57
+ odesolve_kwargs, nlsolve_kwargs, verbose, kwargs... )
58
+ end
52
59
53
- __resid_nodes = resid_prototype[(resida_len + 1 ): (resida_len + cur_nshoot * N)]
54
- resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
55
- pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. diffmode)
60
+ function __solve_nlproblem! (alg:: MultipleShooting , bcresid_prototype, u_at_nodes, nodes,
61
+ cur_nshoot, N, resida_len, residb_len, iip:: Val , solve_internal_odes!:: S , bca:: B1 ,
62
+ bcb:: B2 , prob, u0; kwargs... ) where {B1, B2, S}
63
+ if __any_sparse_ad (alg. jac_alg)
64
+ J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
65
+ bcresid_prototype, u0, N, cur_nshoot)
66
+ end
56
67
57
- loss_fn = (du, u, p) -> __multiple_shooting_2point_loss! (du, u, p, cur_nshoot,
58
- nodes, Val (iip), solve_internal_odes!, resida_len, residb_len, N, bc[1 ],
59
- bc[2 ])
60
- loss_fnₚ = (du, u) -> loss_fn (du, u, prob. p)
68
+ resid_prototype = vcat (bcresid_prototype[1 ],
69
+ similar (u_at_nodes, cur_nshoot * N), bcresid_prototype[2 ])
61
70
62
- sd_bvp = alg. jac_alg. diffmode isa AbstractSparseADType ?
63
- __sparsity_detection_alg (J_proto) : NoSparsityDetection ()
71
+ __resid_nodes = resid_prototype[(resida_len + 1 ): (resida_len + cur_nshoot * N)]
72
+ resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
73
+ pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. diffmode)
64
74
65
- resid_prototype_cached = similar (resid_prototype)
66
- jac_cache = sparse_jacobian_cache (alg . jac_alg . diffmode, sd_bvp, loss_fnₚ ,
67
- resid_prototype_cached, u_at_nodes )
68
- jac_prototype = init_jacobian (jac_cache )
75
+ loss_fn = (du, u, p) -> __multiple_shooting_2point_loss! (du, u, p, cur_nshoot,
76
+ nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca ,
77
+ bcb )
78
+ loss_fnₚ = (du, u) -> loss_fn (du, u, prob . p )
69
79
70
- jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian! (J, u, p, jac_cache,
71
- loss_fnₚ, resid_prototype_cached, alg)
72
- else
73
- resid_prototype = vcat (bcresid_prototype, similar (u_at_nodes, cur_nshoot * N))
74
-
75
- __resid_nodes = resid_prototype[(end - cur_nshoot * N + 1 ): end ]
76
- resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
77
- pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. bc_diffmode)
78
-
79
- loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
80
- nodes, Val (iip), solve_internal_odes!, prod (resid_size), N, f, bc, u0_size,
81
- tspan, alg. ode_alg)
82
-
83
- ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes)
84
- sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
85
- __sparsity_detection_alg (J_proto) : NoSparsityDetection ()
86
- ode_jac_cache = sparse_jacobian_cache (alg. jac_alg. nonbc_diffmode, sd_ode,
87
- ode_fn, similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
88
-
89
- bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p,
90
- cur_nshoot, nodes, Val (iip), solve_internal_odes!, N, f, bc, u0_size, tspan,
91
- alg. ode_alg)
92
- sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
93
- SymbolicsSparsityDetection () : NoSparsityDetection ()
94
- bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
95
- sd_bc, bc_fn, similar (bcresid_prototype), u_at_nodes)
96
-
97
- jac_prototype = vcat (init_jacobian (bc_jac_cache), init_jacobian (ode_jac_cache))
98
-
99
- jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian! (J, u, p,
100
- similar (bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
101
- ode_fn, bc_fn, alg, N)
102
- end
103
- loss_function! = NonlinearFunction {true} (loss_fn; resid_prototype, jac = jac_fn,
104
- jac_prototype)
80
+ sd_bvp = alg. jac_alg. diffmode isa AbstractSparseADType ?
81
+ __sparsity_detection_alg (J_proto) : NoSparsityDetection ()
82
+
83
+ resid_prototype_cached = similar (resid_prototype)
84
+ jac_cache = sparse_jacobian_cache (alg. jac_alg. diffmode, sd_bvp, loss_fnₚ,
85
+ resid_prototype_cached, u_at_nodes)
86
+ jac_prototype = init_jacobian (jac_cache)
87
+
88
+ jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian! (J, u, p, jac_cache,
89
+ loss_fnₚ, resid_prototype_cached, alg)
90
+
91
+ loss_function! = NonlinearFunction {true} (loss_fn; resid_prototype, jac = jac_fn,
92
+ jac_prototype)
93
+
94
+ # NOTE: u_at_nodes is updated inplace
95
+ nlprob = NonlinearProblem (loss_function!, u_at_nodes, prob. p)
96
+ __solve (nlprob, alg. nlsolve; kwargs... , alias_u0 = true )
105
97
106
- # NOTE: u_at_nodes is updated inplace
107
- nlprob = NonlinearProblem (loss_function!, u_at_nodes, prob. p)
108
- __solve (nlprob, alg. nlsolve; verbose, kwargs... , nlsolve_kwargs... , alias_u0 = true )
98
+ return nothing
99
+ end
100
+
101
+ function __solve_nlproblem! (alg:: MultipleShooting , bcresid_prototype, u_at_nodes, nodes,
102
+ cur_nshoot, N, resid_size, iip:: Val , solve_internal_odes!:: S , bc:: BC ,
103
+ prob, f:: F , u0_size, u0; kwargs... ) where {BC, F, S}
104
+ if __any_sparse_ad (alg. jac_alg)
105
+ J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
106
+ bcresid_prototype, u0, N, cur_nshoot)
109
107
end
108
+ resid_prototype = vcat (bcresid_prototype, similar (u_at_nodes, cur_nshoot * N))
110
109
111
- single_shooting_prob = remake (prob; u0 = reshape (u_at_nodes[1 : N], u0_size))
112
- return __solve (single_shooting_prob, Shooting (alg. ode_alg; alg. nlsolve);
113
- odesolve_kwargs, nlsolve_kwargs, verbose, kwargs... )
110
+ __resid_nodes = resid_prototype[(end - cur_nshoot * N + 1 ): end ]
111
+ resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
112
+ pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. bc_diffmode)
113
+
114
+ loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
115
+ nodes, iip, solve_internal_odes!, prod (resid_size), N, f, bc, u0_size,
116
+ tspan, alg. ode_alg)
117
+
118
+ ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes)
119
+ sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
120
+ __sparsity_detection_alg (J_proto) : NoSparsityDetection ()
121
+ ode_jac_cache = sparse_jacobian_cache (alg. jac_alg. nonbc_diffmode, sd_ode,
122
+ ode_fn, similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
123
+
124
+ bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc! (du, u, prob. p,
125
+ cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg. ode_alg)
126
+ sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
127
+ SymbolicsSparsityDetection () : NoSparsityDetection ()
128
+ bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
129
+ sd_bc, bc_fn, similar (bcresid_prototype), u_at_nodes)
130
+
131
+ jac_prototype = vcat (init_jacobian (bc_jac_cache), init_jacobian (ode_jac_cache))
132
+
133
+ jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian! (J, u, p,
134
+ similar (bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
135
+ ode_fn, bc_fn, alg, N)
136
+
137
+ loss_function! = NonlinearFunction {true} (loss_fn; resid_prototype, jac = jac_fn,
138
+ jac_prototype)
139
+
140
+ # NOTE: u_at_nodes is updated inplace
141
+ nlprob = NonlinearProblem (loss_function!, u_at_nodes, prob. p)
142
+ __solve (nlprob, alg. nlsolve; kwargs... , alias_u0 = true )
143
+
144
+ return nothing
114
145
end
115
146
116
147
function __multiple_shooting_solve_internal_odes! (resid_nodes, us, p, :: Val{iip} , f,
185
216
end
186
217
187
218
@views function __multiple_shooting_mpoint_loss_bc! (resid_bc, us, p, cur_nshoots:: Int ,
188
- nodes,
189
- :: Val{iip} , solve_internal_odes!, N, f, bc, u0_size, tspan, ode_alg) where {iip}
219
+ nodes, :: Val{iip} , solve_internal_odes!, N, f, bc, u0_size, tspan, ode_alg) where {iip}
190
220
_resid_nodes = similar (us, cur_nshoots * N)
191
221
192
222
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
313
343
idxs_prev = (N + (fpos - 2 ) * N .+ (1 : N))
314
344
ustart = u_at_nodes_prev[idxs_prev]
315
345
316
- odeprob = ODEProblem (f, ustart, (t0, tstop), p)
346
+ # https://github.com/SciML/DifferentialEquations.jl/issues/975
347
+ # odeprob = ODEProblem(f, ustart, (t0, tstop), p)
348
+ odeprob = ODEProblem (f, copy (ustart), (t0, tstop), p)
317
349
odesol = __solve (odeprob, alg. ode_alg; kwargs... , saveat = (), save_end = true )
318
350
319
351
u_at_nodes[idxs] .= odesol. u[end ]
324
356
end
325
357
326
358
@inline function __get_all_nshoots (g:: Bool , nshoots)
327
- return g ? __get_all_nshoots (Base. Fix2 (÷ , 2 )) : [nshoots]
359
+ return g ? __get_all_nshoots (Base. Fix2 (÷ , 2 ), nshoots ) : [nshoots]
328
360
end
329
361
@inline function __get_all_nshoots (g, nshoots)
330
362
first (g) == nshoots && return g
0 commit comments