Skip to content

Commit 2b22be3

Browse files
committed
A bit less runtime dispatches
1 parent 9d2b4d7 commit 2b22be3

File tree

4 files changed

+35
-26
lines changed

4 files changed

+35
-26
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ version = "1.10.0"
802802

803803
[[deps.SparseDiffTools]]
804804
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
805-
git-tree-sha1 = "766a1f0d19232243de203b84d1a7713f6624a1e4"
805+
git-tree-sha1 = "fcc87cf2750372313d0f761b7dbd75b559c646a4"
806806
repo-rev = "patch-1"
807807
repo-url = "https://github.com/avik-pal/SparseDiffTools.jl"
808808
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

ext/BoundaryValueDiffEqOrdinaryDiffEqExt.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,17 @@ end
3939
TwoPointBVProblem(f1, (bc1_a, bc1_b), u0, tspan; bcresid_prototype),
4040
]
4141

42+
algs = [
43+
Shooting(Tsit5()),
44+
MultipleShooting(10,
45+
Tsit5();
46+
jac_alg = BVPJacobianAlgorithm(;
47+
bc_diffmode = AutoForwardDiff(; chunksize = 2),
48+
nonbc_diffmode = AutoSparseForwardDiff(; chunksize = 2))),
49+
]
50+
4251
@compile_workload begin
43-
for prob in probs, alg in (Shooting(Tsit5()), MultipleShooting(10, Tsit5()))
52+
for prob in probs, alg in algs
4453
solve(prob, alg)
4554
end
4655
end

src/solve/multiple_shooting.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,8 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
6868
resid_prototype = vcat(bcresid_prototype[1],
6969
similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2])
7070

71-
__resid_nodes = resid_prototype[(resida_len + 1):(resida_len + cur_nshoot * N)]
72-
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
73-
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.diffmode)
74-
7571
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
76-
nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca,
77-
bcb)
72+
nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca, bcb)
7873
loss_fnₚ = (du, u) -> loss_fn(du, u, prob.p)
7974

8075
sd_bvp = alg.jac_alg.diffmode isa AbstractSparseADType ?
@@ -113,7 +108,7 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
113108

114109
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
115110
nodes, iip, solve_internal_odes!, prod(resid_size), N, f, bc, u0_size,
116-
tspan, alg.ode_alg)
111+
tspan, alg.ode_alg, u0)
117112

118113
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes)
119114
sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ?
@@ -122,7 +117,8 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
122117
ode_fn, similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
123118

124119
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p,
125-
cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg.ode_alg)
120+
cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg.ode_alg,
121+
u0)
126122
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
127123
SymbolicsSparsityDetection() : NoSparsityDetection()
128124
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
@@ -144,9 +140,9 @@ function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes
144140
return nothing
145141
end
146142

147-
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, ::Val{iip}, f,
143+
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, ::Val{iip}, f::F,
148144
cur_nshoots::Int, nodes, tspan, u0_size, N, alg::MultipleShooting,
149-
ensemblealg, kwargs) where {iip}
145+
ensemblealg, kwargs) where {iip, F}
150146
ts_ = Vector{Vector{typeof(first(tspan))}}(undef, cur_nshoots)
151147
us_ = Vector{Vector{typeof(us)}}(undef, cur_nshoots)
152148

@@ -194,7 +190,7 @@ function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
194190
end
195191

196192
@views function __multiple_shooting_2point_loss!(resid, us, p, cur_nshoots::Int, nodes,
197-
::Val{iip}, solve_internal_odes!, resida_len, residb_len, N, bca, bcb) where {iip}
193+
::Val{iip}, solve_internal_odes!::S, resida_len, residb_len, N, bca, bcb) where {iip, S}
198194
resid_ = resid[(resida_len + 1):(end - residb_len)]
199195
solve_internal_odes!(resid_, us, p, cur_nshoots, nodes)
200196

@@ -216,13 +212,14 @@ end
216212
end
217213

218214
@views function __multiple_shooting_mpoint_loss_bc!(resid_bc, us, p, cur_nshoots::Int,
219-
nodes, ::Val{iip}, solve_internal_odes!, N, f, bc, u0_size, tspan, ode_alg) where {iip}
215+
nodes, ::Val{iip}, solve_internal_odes!::S, N, f, bc, u0_size, tspan,
216+
ode_alg, u0) where {iip, S}
220217
_resid_nodes = similar(us, cur_nshoots * N)
221218

222219
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
223220
_us, _ts = solve_internal_odes!(_resid_nodes, us, p, cur_nshoots, nodes)
224221

225-
odeprob = ODEProblem{iip}(f, reshape(us[1:N], u0_size), tspan, p)
222+
odeprob = ODEProblem{iip}(f, u0, tspan, p)
226223
total_solution = SciMLBase.build_solution(odeprob, ode_alg, _ts, _us)
227224

228225
if iip
@@ -235,14 +232,14 @@ end
235232
end
236233

237234
@views function __multiple_shooting_mpoint_loss!(resid, us, p, cur_nshoots::Int, nodes,
238-
::Val{iip}, solve_internal_odes!, resid_len, N, f, bc, u0_size, tspan,
239-
ode_alg) where {iip}
235+
::Val{iip}, solve_internal_odes!::S, resid_len, N, f, bc, u0_size, tspan,
236+
ode_alg, u0) where {iip, S}
240237
resid_bc = resid[1:resid_len]
241238
resid_nodes = resid[(resid_len + 1):end]
242239

243240
_us, _ts = solve_internal_odes!(resid_nodes, us, p, cur_nshoots, nodes)
244241

245-
odeprob = ODEProblem{iip}(f, reshape(us[1:N], u0_size), tspan, p)
242+
odeprob = ODEProblem{iip}(f, u0, tspan, p)
246243
total_solution = SciMLBase.build_solution(odeprob, ode_alg, _ts, _us)
247244

248245
if iip

src/utils.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,27 +71,30 @@ function __maybe_matmul!(z, A, b, α = eltype(z)(1), β = eltype(z)(0))
7171
end
7272

7373
## Easier to dispatch
74-
eval_bc_residual(pt, bc, sol, p) = eval_bc_residual(pt, bc, sol, p, sol.t)
75-
eval_bc_residual(_, bc, sol, p, t) = bc(sol, p, t)
76-
function eval_bc_residual(::TwoPointBVProblem, (bca, bcb), sol, p, t)
74+
eval_bc_residual(pt, bc::BC, sol, p) where {BC} = eval_bc_residual(pt, bc, sol, p, sol.t)
75+
eval_bc_residual(_, bc::BC, sol, p, t) where {BC} = bc(sol, p, t)
76+
function eval_bc_residual(::TwoPointBVProblem, (bca, bcb)::BC, sol, p, t) where {BC}
7777
ua = sol isa AbstractVector ? sol[1] : sol(first(t))
7878
ub = sol isa AbstractVector ? sol[end] : sol(last(t))
7979
resida = bca(ua, p)
8080
residb = bcb(ub, p)
8181
return (resida, residb)
8282
end
8383

84-
eval_bc_residual!(resid, pt, bc!, sol, p) = eval_bc_residual!(resid, pt, bc!, sol, p, sol.t)
85-
eval_bc_residual!(resid, _, bc!, sol, p, t) = bc!(resid, sol, p, t)
86-
@views function eval_bc_residual!(resid, ::TwoPointBVProblem, (bca!, bcb!), sol, p, t)
84+
function eval_bc_residual!(resid, pt, bc!::BC, sol, p) where {BC}
85+
return eval_bc_residual!(resid, pt, bc!, sol, p, sol.t)
86+
end
87+
eval_bc_residual!(resid, _, bc!::BC, sol, p, t) where {BC} = bc!(resid, sol, p, t)
88+
@views function eval_bc_residual!(resid, ::TwoPointBVProblem, (bca!, bcb!)::BC, sol, p,
89+
t) where {BC}
8790
ua = sol isa AbstractVector ? sol[1] : sol(first(t))
8891
ub = sol isa AbstractVector ? sol[end] : sol(last(t))
8992
bca!(resid.resida, ua, p)
9093
bcb!(resid.residb, ub, p)
9194
return resid
9295
end
93-
@views function eval_bc_residual!(resid::Tuple, ::TwoPointBVProblem, (bca!, bcb!), sol, p,
94-
t)
96+
@views function eval_bc_residual!(resid::Tuple, ::TwoPointBVProblem, (bca!, bcb!)::BC, sol,
97+
p, t) where {BC}
9598
ua = sol isa AbstractVector ? sol[1] : sol(first(t))
9699
ub = sol isa AbstractVector ? sol[end] : sol(last(t))
97100
bca!(resid[1], ua, p)

0 commit comments

Comments
 (0)