Skip to content

Commit abf3be0

Browse files
committed
Setup for a specialized Multiple Shooting Algorithm
1 parent 6526a57 commit abf3be0

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

src/algorithms.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ end
6262

6363
"""
6464
MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
65-
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
65+
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(),
66+
static_auto_nodes::Val = Val(false))
6667
6768
Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP.
6869
Significantly more stable than Single Shooting.
@@ -97,22 +98,25 @@ Significantly more stable than Single Shooting.
9798
- `Function`: Takes the current number of shooting points and returns the next number
9899
of shooting points. For example, if `nshoots = 10` and
99100
`grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`.
101+
- `static_auto_nodes`: Automatically detect the timepoints used in the boundary condition
102+
and use a faster version of the algorithm! This particular keyword argument should be
103+
considered experimental and should be used with care!
100104
101105
!!! note
102106
For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm`
103107
must be provided.
104108
"""
105-
@concrete struct MultipleShooting{J <: BVPJacobianAlgorithm}
109+
@concrete struct MultipleShooting{S, J <: BVPJacobianAlgorithm}
106110
ode_alg
107111
nlsolve
108112
jac_alg::J
109113
nshoots::Int
110114
grid_coarsening
111115
end
112116

113-
function concretize_jacobian_algorithm(alg::MultipleShooting, prob)
117+
function concretize_jacobian_algorithm(alg::MultipleShooting{S}, prob) where {S}
114118
jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
115-
return MultipleShooting(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots,
119+
return MultipleShooting{S}(alg.ode_alg, alg.nlsolve, jac_alg, alg.nshoots,
116120
alg.grid_coarsening)
117121
end
118122

@@ -122,16 +126,18 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int)
122126
end
123127

124128
function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
125-
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm())
129+
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(),
130+
static_auto_nodes::Val{S} = Val(false)) where {S}
126131
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
127132
grid_coarsening isa AbstractVector{<:Integer} ||
128133
grid_coarsening isa NTuple{N, <:Integer} where {N}
134+
@assert S isa Bool
129135
grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...))
130136
if grid_coarsening isa AbstractVector
131137
sort!(grid_coarsening; rev = true)
132138
@assert all(grid_coarsening .> 0) && 1 grid_coarsening
133139
end
134-
return MultipleShooting(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening)
140+
return MultipleShooting{S}(ode_alg, nlsolve, jac_alg, nshoots, grid_coarsening)
135141
end
136142

137143
for order in (2, 3, 4, 5, 6)

src/solve/multiple_shooting.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
1-
function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
1+
function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs = (;),
2+
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
3+
# For TwoPointBVPs there is nothing to do. Forward to general multiple shooting
4+
prob.problem_type isa TwoPointBVProblem &&
5+
return __solve_internal(prob, _alg; kwargs...)
6+
7+
# Extract the time-points used in BC
8+
_prob = ODEProblem{isinplace(prob)}(prob.f, prob.u0, prob.tspan, prob.p)
9+
end
10+
11+
function __solve(prob::BVProblem, _alg::MultipleShooting{false}; kwargs...)
12+
return __solve_internal(prob, _alg; kwargs...)
13+
end
14+
15+
function __solve_internal(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
216
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
317
@unpack f, tspan = prob
418

0 commit comments

Comments
 (0)