Skip to content

Commit 09eb18d

Browse files
committed
Fix some more runtime dispatches
1 parent 73d8385 commit 09eb18d

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

ext/BoundaryValueDiffEqOrdinaryDiffEqExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ end
4040
]
4141

4242
algs = [
43-
Shooting(Tsit5()),
43+
Shooting(Tsit5();
44+
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2))),
4445
MultipleShooting(10,
4546
Tsit5();
47+
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2)),
4648
jac_alg = BVPJacobianAlgorithm(;
4749
bc_diffmode = AutoForwardDiff(; chunksize = 2),
4850
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),

src/solve/multiple_shooting.jl

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
2424
end
2525

2626
internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true)
27-
solve_internal_odes! = (resid_nodes, us, p, cur_nshoot, nodes) -> __multiple_shooting_solve_internal_odes!(resid_nodes,
28-
us, p, Val(iip), f, cur_nshoot, nodes, tspan, u0_size, N, alg, ensemblealg,
29-
internal_ode_kwargs)
27+
function solve_internal_odes!(resid_nodes::T1, us::T2, p::T3, cur_nshoot::Int,
28+
nodes::T4) where {T1, T2, T3, T4}
29+
return __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, prob, f,
30+
cur_nshoot, nodes, tspan, u0_size, N, alg, ensemblealg, internal_ode_kwargs)
31+
end
3032

3133
# This gets all the nshoots except the final SingleShooting case
3234
all_nshoots = __get_all_nshoots(alg.grid_coarsening, nshoots)
@@ -42,13 +44,13 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
4244
end
4345

4446
if prob.problem_type isa TwoPointBVProblem
45-
__solve_nlproblem!(alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N,
46-
resida_len, residb_len, Val(iip), solve_internal_odes!, bc[1], bc[2], prob,
47-
u0; verbose, kwargs..., nlsolve_kwargs...)
47+
__solve_nlproblem!(prob.problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
48+
cur_nshoot, N, resida_len, residb_len, solve_internal_odes!, bc[1], bc[2],
49+
prob, u0; verbose, kwargs..., nlsolve_kwargs...)
4850
else
49-
__solve_nlproblem!(alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N,
50-
prod(resid_size), Val(iip), solve_internal_odes!, bc, prob, f, u0_size, u0;
51-
verbose, kwargs..., nlsolve_kwargs...)
51+
__solve_nlproblem!(prob.problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
52+
cur_nshoot, N, prod(resid_size), solve_internal_odes!, bc, prob, f,
53+
u0_size, u0; verbose, kwargs..., nlsolve_kwargs...)
5254
end
5355
end
5456

@@ -57,9 +59,9 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
5759
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...)
5860
end
5961

60-
function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes, nodes,
61-
cur_nshoot, N, resida_len, residb_len, iip::Val, solve_internal_odes!::S, bca::B1,
62-
bcb::B2, prob, u0; kwargs...) where {B1, B2, S}
62+
function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_prototype,
63+
u_at_nodes, nodes, cur_nshoot::Int, N::Int, resida_len::Int, residb_len::Int,
64+
solve_internal_odes!::S, bca::B1, bcb::B2, prob, u0; kwargs...) where {B1, B2, S}
6365
if __any_sparse_ad(alg.jac_alg)
6466
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
6567
bcresid_prototype, u0, N, cur_nshoot)
@@ -69,7 +71,7 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
6971
similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2])
7072

7173
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
72-
nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
74+
nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
7375
loss_fnₚ = (du, u) -> loss_fn(du, u, prob.p)
7476

7577
sd_bvp = alg.jac_alg.diffmode isa AbstractSparseADType ?
@@ -93,8 +95,8 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
9395
return nothing
9496
end
9597

96-
function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes, nodes,
97-
cur_nshoot, N, resid_size, iip::Val, solve_internal_odes!::S, bc::BC,
98+
function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_prototype,
99+
u_at_nodes, nodes, cur_nshoot, N, resid_len::Int, solve_internal_odes!::S, bc::BC,
98100
prob, f::F, u0_size, u0; kwargs...) where {BC, F, S}
99101
if __any_sparse_ad(alg.jac_alg)
100102
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
@@ -107,18 +109,17 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
107109
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)
108110

109111
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
110-
nodes, iip, solve_internal_odes!, prod(resid_size), N, f, bc, u0_size,
111-
tspan, alg.ode_alg, u0)
112+
nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan,
113+
alg.ode_alg, u0)
112114

113115
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes)
114116
sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ?
115117
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
116118
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode,
117119
ode_fn, similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
118120

119-
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p,
120-
cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg.ode_alg,
121-
u0)
121+
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p, cur_nshoot, nodes,
122+
prob, solve_internal_odes!, N, f, bc, u0_size, tspan, alg.ode_alg, u0)
122123
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
123124
SymbolicsSparsityDetection() : NoSparsityDetection()
124125
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
@@ -140,9 +141,10 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
140141
return nothing
141142
end
142143

143-
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, ::Val{iip}, f::F,
144+
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, _prob, f::F,
144145
cur_nshoots::Int, nodes, tspan, u0_size, N, alg::MultipleShooting,
145-
ensemblealg, kwargs) where {iip, F}
146+
ensemblealg, kwargs) where {F}
147+
iip = isinplace(_prob)
146148
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
147149
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)
148150

@@ -171,14 +173,15 @@ function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, ::Val{iip}
171173
return reduce(vcat, ensemble_sol.u.us), reduce(vcat, ensemble_sol.u.ts)
172174
end
173175

174-
function __multiple_shooting_2point_jacobian!(J, us, p, jac_cache, loss_fn, resid,
175-
alg::MultipleShooting)
176+
function __multiple_shooting_2point_jacobian!(J, us, p, jac_cache, loss_fn::F, resid,
177+
alg::MultipleShooting) where {F}
176178
sparse_jacobian!(J, alg.jac_alg.diffmode, jac_cache, loss_fn, resid, us)
177179
return nothing
178180
end
179181

180182
function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
181-
ode_jac_cache, bc_jac_cache, ode_fn, bc_fn, alg::MultipleShooting, N::Int)
183+
ode_jac_cache, bc_jac_cache, ode_fn::F1, bc_fn::F2, alg::MultipleShooting,
184+
N::Int) where {F1, F2}
182185
J_bc = @view(J[1:N, :])
183186
J_c = @view(J[(N + 1):end, :])
184187

@@ -190,8 +193,8 @@ function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
190193
end
191194

192195
@views function __multiple_shooting_2point_loss!(resid, us, p, cur_nshoots::Int, nodes,
193-
::Val{iip}, solve_internal_odes!::S, resida_len, residb_len, N, bca,
194-
bcb) where {iip, S}
196+
prob, solve_internal_odes!::S, resida_len, residb_len, N, bca::BCA,
197+
bcb::BCB) where {S, BCA, BCB}
195198
resid_ = resid[(resida_len + 1):(end - residb_len)]
196199
solve_internal_odes!(resid_, us, p, cur_nshoots, nodes)
197200

@@ -201,7 +204,7 @@ end
201204
ua = us[1:N]
202205
ub = us[(end - N + 1):end]
203206

204-
if iip
207+
if isinplace(prob)
205208
bca(resid_bc_a, ua, p)
206209
bcb(resid_bc_b, ub, p)
207210
else
@@ -213,8 +216,9 @@ end
213216
end
214217

215218
@views function __multiple_shooting_mpoint_loss_bc!(resid_bc, us, p, cur_nshoots::Int,
216-
nodes, ::Val{iip}, solve_internal_odes!::S, N, f, bc, u0_size, tspan,
217-
ode_alg, u0) where {iip, S}
219+
nodes, prob, solve_internal_odes!::S, N, f::F, bc::BC, u0_size, tspan,
220+
ode_alg, u0) where {S, F, BC}
221+
iip = isinplace(prob)
218222
_resid_nodes = similar(us, cur_nshoots * N)
219223

220224
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
@@ -233,8 +237,9 @@ end
233237
end
234238

235239
@views function __multiple_shooting_mpoint_loss!(resid, us, p, cur_nshoots::Int, nodes,
236-
::Val{iip}, solve_internal_odes!::S, resid_len, N, f, bc, u0_size, tspan,
237-
ode_alg, u0) where {iip, S}
240+
prob, solve_internal_odes!::S, resid_len, N, f::F, bc::BC, u0_size, tspan,
241+
ode_alg, u0) where {S, F, BC}
242+
iip = isinplace(prob)
238243
resid_bc = resid[1:resid_len]
239244
resid_nodes = resid[(resid_len + 1):end]
240245

0 commit comments

Comments
 (0)