Skip to content

Commit 9d2b4d7

Browse files
committed
Try reducing runtime dispatches
1 parent 60536ee commit 9d2b4d7

File tree

2 files changed

+97
-65
lines changed

2 files changed

+97
-65
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ version = "1.2.0"
559559

560560
[[deps.NonlinearSolve]]
561561
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "EnumX", "FiniteDiff", "ForwardDiff", "LineSearches", "LinearAlgebra", "LinearSolve", "PrecompileTools", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "UnPack"]
562-
git-tree-sha1 = "bc8998f278128f4ef91b140127aa014ec20e3b33"
562+
git-tree-sha1 = "ed0c7e75b1e6fac8e87d037a366c0094d49d1904"
563563
repo-rev = "ap/fixes"
564564
repo-url = "https://github.com/avik-pal/NonlinearSolve.jl"
565565
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"

src/solve/multiple_shooting.jl

Lines changed: 96 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -41,76 +41,107 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
4141
cur_nshoot, all_nshoots[i - 1], ig; kwargs..., verbose, odesolve_kwargs...)
4242
end
4343

44-
if __any_sparse_ad(alg.jac_alg)
45-
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
46-
bcresid_prototype, u0, N, cur_nshoot)
44+
if prob.problem_type isa TwoPointBVProblem
45+
__solve_nlproblem!(alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N,
46+
resida_len, residb_len, Val(iip), solve_internal_odes!, bc[1], bc[2], prob,
47+
u0; verbose, kwargs..., nlsolve_kwargs...)
48+
else
49+
__solve_nlproblem!(alg, bcresid_prototype, u_at_nodes, nodes, cur_nshoot, N,
50+
prod(resid_size), Val(iip), solve_internal_odes!, bc, prob, f, u0_size, u0;
51+
verbose, kwargs..., nlsolve_kwargs...)
4752
end
53+
end
4854

49-
if prob.problem_type isa TwoPointBVProblem
50-
resid_prototype = vcat(bcresid_prototype[1],
51-
similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2])
55+
single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size))
56+
return __solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve);
57+
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...)
58+
end
5259

53-
__resid_nodes = resid_prototype[(resida_len + 1):(resida_len + cur_nshoot * N)]
54-
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
55-
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.diffmode)
60+
function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes, nodes,
61+
cur_nshoot, N, resida_len, residb_len, iip::Val, solve_internal_odes!::S, bca::B1,
62+
bcb::B2, prob, u0; kwargs...) where {B1, B2, S}
63+
if __any_sparse_ad(alg.jac_alg)
64+
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
65+
bcresid_prototype, u0, N, cur_nshoot)
66+
end
5667

57-
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
58-
nodes, Val(iip), solve_internal_odes!, resida_len, residb_len, N, bc[1],
59-
bc[2])
60-
loss_fnₚ = (du, u) -> loss_fn(du, u, prob.p)
68+
resid_prototype = vcat(bcresid_prototype[1],
69+
similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2])
6170

62-
sd_bvp = alg.jac_alg.diffmode isa AbstractSparseADType ?
63-
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
71+
__resid_nodes = resid_prototype[(resida_len + 1):(resida_len + cur_nshoot * N)]
72+
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
73+
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.diffmode)
6474

65-
resid_prototype_cached = similar(resid_prototype)
66-
jac_cache = sparse_jacobian_cache(alg.jac_alg.diffmode, sd_bvp, loss_fnₚ,
67-
resid_prototype_cached, u_at_nodes)
68-
jac_prototype = init_jacobian(jac_cache)
75+
loss_fn = (du, u, p) -> __multiple_shooting_2point_loss!(du, u, p, cur_nshoot,
76+
nodes, iip, solve_internal_odes!, resida_len, residb_len, N, bca,
77+
bcb)
78+
loss_fnₚ = (du, u) -> loss_fn(du, u, prob.p)
6979

70-
jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian!(J, u, p, jac_cache,
71-
loss_fnₚ, resid_prototype_cached, alg)
72-
else
73-
resid_prototype = vcat(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N))
74-
75-
__resid_nodes = resid_prototype[(end - cur_nshoot * N + 1):end]
76-
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
77-
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)
78-
79-
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
80-
nodes, Val(iip), solve_internal_odes!, prod(resid_size), N, f, bc, u0_size,
81-
tspan, alg.ode_alg)
82-
83-
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes)
84-
sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ?
85-
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
86-
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode,
87-
ode_fn, similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
88-
89-
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p,
90-
cur_nshoot, nodes, Val(iip), solve_internal_odes!, N, f, bc, u0_size, tspan,
91-
alg.ode_alg)
92-
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
93-
SymbolicsSparsityDetection() : NoSparsityDetection()
94-
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
95-
sd_bc, bc_fn, similar(bcresid_prototype), u_at_nodes)
96-
97-
jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))
98-
99-
jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
100-
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
101-
ode_fn, bc_fn, alg, N)
102-
end
103-
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
104-
jac_prototype)
80+
sd_bvp = alg.jac_alg.diffmode isa AbstractSparseADType ?
81+
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
82+
83+
resid_prototype_cached = similar(resid_prototype)
84+
jac_cache = sparse_jacobian_cache(alg.jac_alg.diffmode, sd_bvp, loss_fnₚ,
85+
resid_prototype_cached, u_at_nodes)
86+
jac_prototype = init_jacobian(jac_cache)
87+
88+
jac_fn = (J, u, p) -> __multiple_shooting_2point_jacobian!(J, u, p, jac_cache,
89+
loss_fnₚ, resid_prototype_cached, alg)
90+
91+
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
92+
jac_prototype)
93+
94+
# NOTE: u_at_nodes is updated inplace
95+
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
96+
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)
10597

106-
# NOTE: u_at_nodes is updated inplace
107-
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
108-
__solve(nlprob, alg.nlsolve; verbose, kwargs..., nlsolve_kwargs..., alias_u0 = true)
98+
return nothing
99+
end
100+
101+
function __solve_nlproblem!(alg::MultipleShooting, bcresid_prototype, u_at_nodes, nodes,
102+
cur_nshoot, N, resid_size, iip::Val, solve_internal_odes!::S, bc::BC,
103+
prob, f::F, u0_size, u0; kwargs...) where {BC, F, S}
104+
if __any_sparse_ad(alg.jac_alg)
105+
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
106+
bcresid_prototype, u0, N, cur_nshoot)
109107
end
108+
resid_prototype = vcat(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N))
110109

111-
single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size))
112-
return __solve(single_shooting_prob, Shooting(alg.ode_alg; alg.nlsolve);
113-
odesolve_kwargs, nlsolve_kwargs, verbose, kwargs...)
110+
__resid_nodes = resid_prototype[(end - cur_nshoot * N + 1):end]
111+
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
112+
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)
113+
114+
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
115+
nodes, iip, solve_internal_odes!, prod(resid_size), N, f, bc, u0_size,
116+
tspan, alg.ode_alg)
117+
118+
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes)
119+
sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ?
120+
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
121+
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode,
122+
ode_fn, similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
123+
124+
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc!(du, u, prob.p,
125+
cur_nshoot, nodes, iip, solve_internal_odes!, N, f, bc, u0_size, tspan, alg.ode_alg)
126+
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
127+
SymbolicsSparsityDetection() : NoSparsityDetection()
128+
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
129+
sd_bc, bc_fn, similar(bcresid_prototype), u_at_nodes)
130+
131+
jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))
132+
133+
jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
134+
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
135+
ode_fn, bc_fn, alg, N)
136+
137+
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
138+
jac_prototype)
139+
140+
# NOTE: u_at_nodes is updated inplace
141+
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
142+
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)
143+
144+
return nothing
114145
end
115146

116147
function __multiple_shooting_solve_internal_odes!(resid_nodes, us, p, ::Val{iip}, f,
@@ -185,8 +216,7 @@ end
185216
end
186217

187218
@views function __multiple_shooting_mpoint_loss_bc!(resid_bc, us, p, cur_nshoots::Int,
188-
nodes,
189-
::Val{iip}, solve_internal_odes!, N, f, bc, u0_size, tspan, ode_alg) where {iip}
219+
nodes, ::Val{iip}, solve_internal_odes!, N, f, bc, u0_size, tspan, ode_alg) where {iip}
190220
_resid_nodes = similar(us, cur_nshoots * N)
191221

192222
# NOTE: We need to recompute this to correctly propagate the dual numbers / gradients
@@ -313,7 +343,9 @@ end
313343
idxs_prev = (N + (fpos - 2) * N .+ (1:N))
314344
ustart = u_at_nodes_prev[idxs_prev]
315345

316-
odeprob = ODEProblem(f, ustart, (t0, tstop), p)
346+
# https://github.com/SciML/DifferentialEquations.jl/issues/975
347+
# odeprob = ODEProblem(f, ustart, (t0, tstop), p)
348+
odeprob = ODEProblem(f, copy(ustart), (t0, tstop), p)
317349
odesol = __solve(odeprob, alg.ode_alg; kwargs..., saveat = (), save_end = true)
318350

319351
u_at_nodes[idxs] .= odesol.u[end]
@@ -324,7 +356,7 @@ end
324356
end
325357

326358
@inline function __get_all_nshoots(g::Bool, nshoots)
327-
return g ? __get_all_nshoots(Base.Fix2(÷, 2)) : [nshoots]
359+
return g ? __get_all_nshoots(Base.Fix2(÷, 2), nshoots) : [nshoots]
328360
end
329361
@inline function __get_all_nshoots(g, nshoots)
330362
first(g) == nshoots && return g

0 commit comments

Comments
 (0)