Skip to content

Commit 00d8e04

Browse files
authored
Merge pull request #131 from avik-pal/ap/shooting_cache
2 parents e760b8d + d5db8a1 commit 00d8e04

File tree

8 files changed

+340
-121
lines changed

8 files changed

+340
-121
lines changed

ext/BoundaryValueDiffEqOrdinaryDiffEqExt.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ end
4444

4545
if @load_preference("PrecompileShooting", true)
4646
push!(algs,
47-
Shooting(Tsit5();
48-
nlsolve = NewtonRaphson(; autodiff = AutoForwardDiff(chunksize = 2))))
47+
Shooting(Tsit5(); nlsolve = NewtonRaphson(),
48+
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))))
4949
end
5050

5151
if @load_preference("PrecompileMultipleShooting", true)
@@ -108,11 +108,10 @@ end
108108
if @load_preference("PrecompileShootingNLLS", VERSIONv"1.10-")
109109
append!(algs,
110110
[
111-
Shooting(Tsit5();
112-
nlsolve = LevenbergMarquardt(;
113-
autodiff = AutoForwardDiff(chunksize = 2))),
114-
Shooting(Tsit5();
115-
nlsolve = GaussNewton(; autodiff = AutoForwardDiff(chunksize = 2))),
111+
Shooting(Tsit5(); nlsolve = LevenbergMarquardt(),
112+
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))),
113+
Shooting(Tsit5(); nlsolve = GaussNewton(),
114+
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2))),
116115
])
117116
end
118117

src/algorithms.jl

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ abstract type BoundaryValueDiffEqAlgorithm <: SciMLBase.AbstractBVPAlgorithm end
33
abstract type AbstractMIRK <: BoundaryValueDiffEqAlgorithm end
44

55
"""
6-
Shooting(ode_alg; nlsolve = NewtonRaphson())
6+
Shooting(ode_alg; nlsolve = NewtonRaphson(), jac_alg = BVPJacobianAlgorithm())
77
88
Single shooting method, reduces BVP to an initial value problem and solves the IVP.
99
@@ -15,19 +15,50 @@ Single shooting method, reduces BVP to an initial value problem and solves the I
1515
## Keyword Arguments
1616
1717
- `nlsolve`: Internal Nonlinear solver. Any solver which conforms to the SciML
18-
`NonlinearProblem` interface can be used.
18+
`NonlinearProblem` interface can be used. Note that any autodiff argument for the solver
19+
will be ignored and a custom jacobian algorithm will be used.
20+
- `jac_alg`: Jacobian Algorithm used for the nonlinear solver. Defaults to
21+
`BVPJacobianAlgorithm()`, which automatically decides the best algorithm to use based
22+
on the input types and problem type. Only `diffmode` is used (defaults to
23+
`AutoForwardDiff` if possible else `AutoFiniteDiff`).
1924
2025
!!! note
21-
For type-stability, you need to specify the chunksize for autodiff. This can be done
22-
via `NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = <chunksize>))`.
23-
Alternatively, you can use other ADTypes!
26+
For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm`
27+
must be provided.
2428
"""
25-
struct Shooting{O, N} <: BoundaryValueDiffEqAlgorithm
29+
struct Shooting{O, N, L <: BVPJacobianAlgorithm} <: BoundaryValueDiffEqAlgorithm
2630
ode_alg::O
2731
nlsolve::N
32+
jac_alg::L
2833
end
2934

30-
Shooting(ode_alg; nlsolve = NewtonRaphson()) = Shooting(ode_alg, nlsolve)
35+
function concretize_jacobian_algorithm(alg::Shooting, prob)
36+
jac_alg = alg.jac_alg
37+
diffmode = jac_alg.diffmode === nothing ? __default_nonsparse_ad(prob.u0) :
38+
jac_alg.diffmode
39+
return Shooting(alg.ode_alg, alg.nlsolve, BVPJacobianAlgorithm(diffmode))
40+
end
41+
42+
function Shooting(ode_alg; nlsolve = NewtonRaphson(), jac_alg = nothing)
43+
jac_alg === nothing && (jac_alg = __propagate_nlsolve_ad_to_jac_alg(nlsolve))
44+
return Shooting(ode_alg, nlsolve, jac_alg)
45+
end
46+
47+
Shooting(ode_alg, nlsolve; jac_alg = nothing) = Shooting(ode_alg; nlsolve, jac_alg)
48+
49+
# This is a deprecation path. We forward the `ad` from nonlinear solver to `jac_alg`.
50+
# We will drop this function in
51+
function __propagate_nlsolve_ad_to_jac_alg(nlsolve::N) where {N}
52+
# Defaults so no depwarn
53+
nlsolve === nothing && return BVPJacobianAlgorithm()
54+
ad = hasfield(N, :ad) ? nlsolve.ad : nothing
55+
ad === nothing && return BVPJacobianAlgorithm()
56+
57+
Base.depwarn("Setting autodiff to the nonlinear solver in Shooting has been deprecated \
58+
and will have no effect from the next major release. Update to use \
59+
`BVPJacobianAlgorithm` directly", :Shooting)
60+
return BVPJacobianAlgorithm(ad)
61+
end
3162

3263
"""
3364
MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),

0 commit comments

Comments
 (0)