Skip to content

Commit e0512bb

Browse files
committed
Reuse ODE Solver Cache
1 parent e760b8d commit e0512bb

File tree

4 files changed

+120
-31
lines changed

4 files changed

+120
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BoundaryValueDiffEq"
22
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
3-
version = "5.4.0"
3+
version = "6.0.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/BoundaryValueDiffEqOrdinaryDiffEqExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ end
4444

4545
if @load_preference("PrecompileShooting", true)
4646
push!(algs,
47-
Shooting(Tsit5();
48-
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2))))
47+
Shooting(Tsit5(); nlsolve = NewtonRaphson(),
48+
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))))
4949
end
5050

5151
if @load_preference("PrecompileMultipleShooting", true)

src/algorithms.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ abstract type BoundaryValueDiffEqAlgorithm <: SciMLBase.AbstractBVPAlgorithm end
33
abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end
44

55
"""
6-
Shooting(ode_alg; nlsolve = NewtonRaphson())
6+
Shooting(ode_alg; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
77
88
Single shooting method, reduces BVP to an initial value problem and solves the IVP.
99
@@ -15,19 +15,33 @@ Single shooting method, reduces BVP to an initial value problem and solves the I
1515
## Keyword Arguments
1616
1717
- `nlsolve`: Internal Nonlinear solver. Any solver which conforms to the SciML
18-
`NonlinearProblem` interface can be used.
18+
`NonlinearProblem` interface can be used.Note that any autodiff argument for the solver
19+
will be ignored and a custom jacobian algorithm will be used.
20+
- `jac_alg`: Jacobian Algorithm used for the nonlinear solver. Defaults to
21+
`BVPJacobianAlgorithm()`, which automatically decides the best algorithm to use based
22+
on the input types and problem type. Only `diffmode` is used (defaults to
23+
`AutoForwardDiff` if possible else `AutoFiniteDiff`).
1924
2025
!!! note
21-
For type-stability, you need to specify the chunksize for autodiff. This can be done
22-
via `NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = <chunksize>))`.
23-
Alternatively, you can use other ADTypes!
26+
For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm`
27+
must be provided.
2428
"""
25-
struct Shooting{O, N} <: BoundaryValueDiffEqAlgorithm
29+
struct Shooting{O, N, L <: BVPJacobianAlgorithm} <: BoundaryValueDiffEqAlgorithm
2630
ode_alg::O
2731
nlsolve::N
32+
jac_alg::L
33+
end
34+
35+
function concretize_jacobian_algorithm(alg::Shooting, prob)
36+
jac_alg = alg.jac_alg
37+
diffmode = jac_alg.diffmode === nothing ? __default_nonsparse_ad(prob.u0) :
38+
jac_alg.diffmode
39+
return Shooting(alg.ode_alg, alg.nlsolve, BVPJacobianAlgorithm(diffmode))
2840
end
2941

30-
Shooting(ode_alg; nlsolve = NewtonRaphson()) = Shooting(ode_alg, nlsolve)
42+
function Shooting(ode_alg; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
43+
return Shooting(ode_alg, nlsolve, jac_alg)
44+
end
3145

3246
"""
3347
MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),

src/solve/single_shooting.jl

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,141 @@
1-
function __solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;),
1+
function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),
22
nlsolve_kwargs = (;), verbose = true, kwargs...)
33
ig, T, N, _, u0 = __extract_problem_details(prob; dt = 0.1)
44
_unwrap_val(ig) && verbose &&
55
@warn "Initial guess provided, but will be ignored for Shooting!"
66

7+
alg = concretize_jacobian_algorithm(alg_, prob)
8+
79
bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0)
810
iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)
911
resid_prototype = __vec(bcresid_prototype)
1012

13+
# Construct the residual function
1114
ode_kwargs = (; kwargs..., verbose, odesolve_kwargs...)
15+
internal_prob = ODEProblem{iip}(prob.f, u0, prob.tspan, prob.p)
16+
ode_cache_loss_fn = SciMLBase.__init(internal_prob, alg.ode_alg; ode_kwargs...)
17+
1218
loss_fn = if iip
13-
(du, u, p) -> __single_shooting_loss!(du, u, p, prob.f, bc, u0_size, prob.tspan,
14-
prob.problem_type, resid_size, alg, ode_kwargs)
19+
(du, u, p) -> __single_shooting_loss!(du, u, p, ode_cache_loss_fn, bc, u0_size,
20+
prob.problem_type, resid_size)
21+
else
22+
(u, p) -> __single_shooting_loss(u, p, ode_cache_loss_fn, bc, u0_size,
23+
prob.problem_type)
24+
end
25+
26+
# Construct the jacobian function
27+
# NOTE: We pass in a separate Jacobian Function because that allows us to cache the
28+
# the internal ode solve cache. This cache needs to be distinct from the regular
29+
# residual function cache
30+
loss_fnₚ = ifelse(iip, (du, u) -> loss_fn(du, u, prob.p), (u) -> loss_fn(u, prob.p))
31+
32+
# TODO: We probably won't be able to support Symbolics through ODE Solver but we should
33+
# be able to allow prespecified coloring.
34+
if alg.jac_alg.diffmode isa AbstractSparseADType
35+
error("Single Shooting doesn't support sparse AD yet!")
36+
end
37+
sd = NoSparsityDetection()
38+
y_ = similar(resid_prototype)
39+
40+
jac_cache = if iip
41+
sparse_jacobian_cache(alg.jac_alg.diffmode, sd, loss_fnₚ, y_, vec(u0))
42+
else
43+
sparse_jacobian_cache(alg.jac_alg.diffmode, sd, loss_fnₚ, vec(u0); fx = y_)
44+
end
45+
46+
ode_cache_jac_fn = __single_shooting_jacobian_ode_cache(internal_prob, jac_cache,
47+
alg.jac_alg.diffmode, vec(u0), alg.ode_alg; ode_kwargs...)
48+
49+
jac_prototype = init_jacobian(jac_cache)
50+
51+
loss_fn2ₚ = if iip
52+
(du, u) -> __single_shooting_loss!(du, u, prob.p, ode_cache_jac_fn, bc, u0_size,
53+
prob.problem_type, resid_size)
54+
else
55+
(u) -> __single_shooting_loss(u, prob.p, ode_cache_jac_fn, bc, u0_size,
56+
prob.problem_type)
57+
end
58+
59+
jac_fn = if iip
60+
(J, u, p) -> __single_shooting_jacobian!(J, u, jac_cache, alg.jac_alg.diffmode,
61+
loss_fn2ₚ, y_)
1562
else
16-
(u, p) -> __single_shooting_loss(u, p, prob.f, bc, u0_size, prob.tspan,
17-
prob.problem_type, alg, ode_kwargs)
63+
(u, p) -> __single_shooting_jacobian(jac_prototype, u, jac_cache,
64+
alg.jac_alg.diffmode, loss_fn2ₚ)
1865
end
1966

20-
nlf = NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype, resid_prototype)
67+
nlf = NonlinearFunction{iip}(loss_fn; prob.f.jac_prototype, resid_prototype,
68+
jac = jac_fn)
2169
nlprob = if length(resid_prototype) == length(u0)
2270
NonlinearProblem(nlf, vec(u0), prob.p)
2371
else
2472
NonlinearLeastSquaresProblem(nlf, vec(u0), prob.p)
2573
end
2674
opt = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
2775

28-
newprob = ODEProblem{iip}(prob.f, reshape(opt.u, u0_size), prob.tspan, prob.p)
29-
sol = __solve(newprob, alg.ode_alg; odesolve_kwargs..., verbose, kwargs...)
76+
SciMLBase.reinit!(ode_cache_loss_fn, reshape(opt.u, u0_size))
77+
sol = solve!(ode_cache_loss_fn)
3078

3179
!SciMLBase.successful_retcode(opt) &&
3280
return SciMLBase.solution_new_retcode(sol, ReturnCode.Failure)
3381
return sol
3482
end
3583

36-
function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan,
37-
pt::TwoPointBVProblem, (resida_size, residb_size), alg::Shooting,
38-
kwargs) where {F, BC}
84+
function __single_shooting_loss!(resid_, u0_, p, cache, bc::BC, u0_size,
85+
pt::TwoPointBVProblem, (resida_size, residb_size)) where {BC}
3986
resida = @view resid_[1:prod(resida_size)]
4087
residb = @view resid_[(prod(resida_size) + 1):end]
4188
resid = (reshape(resida, resida_size), reshape(residb, residb_size))
4289

43-
odeprob = ODEProblem{true}(f, reshape(u0_, u0_size), tspan, p)
44-
odesol = __solve(odeprob, alg.ode_alg; kwargs...)
90+
SciMLBase.reinit!(cache, reshape(u0_, u0_size))
91+
odesol = solve!(cache)
92+
4593
eval_bc_residual!(resid, pt, bc, odesol, p)
4694

4795
return nothing
4896
end
4997

50-
function __single_shooting_loss!(resid_, u0_, p, f::F, bc::BC, u0_size, tspan,
51-
pt::StandardBVProblem, resid_size, alg::Shooting, kwargs) where {F, BC}
98+
function __single_shooting_loss!(resid_, u0_, p, cache, bc::BC, u0_size,
99+
pt::StandardBVProblem, resid_size) where {BC}
52100
resid = reshape(resid_, resid_size)
53101

54-
odeprob = ODEProblem{true}(f, reshape(u0_, u0_size), tspan, p)
55-
odesol = __solve(odeprob, alg.ode_alg; kwargs...)
102+
SciMLBase.reinit!(cache, reshape(u0_, u0_size))
103+
odesol = solve!(cache)
104+
56105
eval_bc_residual!(resid, pt, bc, odesol, p)
57106

58107
return nothing
59108
end
60109

61-
function __single_shooting_loss(u0_, p, f::F, bc::BC, u0_size, tspan, pt, alg::Shooting,
62-
kwargs) where {F, BC}
63-
odeprob = ODEProblem{false}(f, reshape(u0_, u0_size), tspan, p)
64-
odesol = __solve(odeprob, alg.ode_alg; kwargs...)
110+
function __single_shooting_loss(u, p, cache, bc::BC, u0_size, pt) where {BC}
111+
SciMLBase.reinit!(cache, reshape(u, u0_size))
112+
odesol = solve!(cache)
65113
return __safe_vec(eval_bc_residual(pt, bc, odesol, p))
66114
end
115+
116+
function __single_shooting_jacobian!(J, u, jac_cache, diffmode, loss_fn::L, fu) where {L}
117+
sparse_jacobian!(J, diffmode, jac_cache, loss_fn, fu, vec(u))
118+
return J
119+
end
120+
121+
function __single_shooting_jacobian(J, u, jac_cache, diffmode, loss_fn::L) where {L}
122+
sparse_jacobian!(J, diffmode, jac_cache, loss_fn, vec(u))
123+
return J
124+
end
125+
126+
function __single_shooting_jacobian_ode_cache(prob, jac_cache, alg, u0, ode_alg; kwargs...)
127+
prob_ = remake(prob; u0)
128+
return SciMLBase.__init(prob_, ode_alg; kwargs...)
129+
end
130+
131+
function __single_shooting_jacobian_ode_cache(prob, jac_cache,
132+
::Union{AutoForwardDiff, AutoSparseForwardDiff}, u0, ode_alg; kwargs...)
133+
cache = jac_cache.cache
134+
if cache isa ForwardDiff.JacobianConfig
135+
xduals = cache.duals isa Tuple ? cache.duals[2] : cache.duals
136+
prob_ = remake(prob; u0 = xduals)
137+
return SciMLBase.__init(prob_, ode_alg; kwargs...)
138+
else
139+
error("Single Shooting doesn't support sparse AD yet!")
140+
end
141+
end

0 commit comments

Comments
 (0)