Skip to content

Commit 4b210c6

Browse files
Merge pull request #120 from avik-pal/ap/ms_improvements
Minor Updates to Multiple Shooting
2 parents 9611bb7 + 0eb5281 commit 4b210c6

File tree

5 files changed

+22
-24
lines changed

5 files changed

+22
-24
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ jobs:
1010
test:
1111
runs-on: ubuntu-latest
1212
strategy:
13+
fail-fast: false
1314
matrix:
1415
group:
1516
- Shooting
1617
- MIRK
1718
- Others
1819
version:
1920
- '1'
21+
- '~1.10.0-0'
2022
steps:
2123
- uses: actions/checkout@v4
2224
- uses: julia-actions/setup-julia@v1

src/solve/multiple_shooting.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
99
iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)
1010

1111
__alg = concretize_jacobian_algorithm(_alg, prob)
12-
alg = if has_initial_guess && Nig != __alg.nshoots + 1
12+
alg = if has_initial_guess && Nig != __alg.nshoots
1313
verbose &&
14-
@warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig - 1)`"
15-
update_nshoots(__alg, Nig - 1)
14+
@warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig)`"
15+
update_nshoots(__alg, Nig)
1616
else
1717
__alg
1818
end
@@ -57,13 +57,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
5757
compute_bc_residual! = if prob.problem_type isa TwoPointBVProblem
5858
@views function compute_bc_residual_tp!(resid_bc, us::ArrayPartition, p,
5959
cur_nshoots, nodes, resid_nodes::Union{Nothing, MaybeDiffCache} = nothing)
60-
ua, ub0 = us.x
61-
# Just Recompute the last ODE Solution
62-
lastodeprob = ODEProblem{iip}(f, reshape(ub0, u0_size),
63-
(nodes[end - 1], nodes[end]), p)
64-
sol_ode_last = __solve(lastodeprob, alg.ode_alg; odesolve_kwargs..., verbose,
65-
kwargs..., save_everystep = false, saveat = (), save_end = true)
66-
ub = vec(sol_ode_last.u[end])
60+
ua, ub = us.x
6761

6862
resid_bc_a, resid_bc_b = if resid_bc isa ArrayPartition
6963
resid_bc.x
@@ -147,7 +141,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
147141
resida, residb = resid_bc.x
148142
J_bc[1:length(resida), 1:N] .= J_bc′[1:length(resida), 1:N]
149143
idxᵢ = (length(resida) + 1):(length(resida) + length(residb))
150-
J_bc[idxᵢ, (end - 2N + 1):(end - N)] .= J_bc′[idxᵢ, (end - N + 1):end]
144+
J_bc[idxᵢ, (end - N + 1):end] .= J_bc′[idxᵢ, (end - N + 1):end]
151145

152146
return nothing
153147
end
@@ -215,7 +209,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
215209

216210
bc_jac_cache = (bc_jac_cache_partial, init_jacobian(bc_jac_cache_partial))
217211

218-
jac_prototype = if @isdefined(J_full)
212+
jac_prototype = if alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ||
213+
alg.jac_alg.bc_diffmode isa AbstractSparseADType
219214
J_full
220215
else
221216
__zeros_like(u_at_nodes, length(resid_prototype), length(u_at_nodes))
@@ -236,7 +231,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
236231
nodes); resid_prototype, jac = jac_fn, jac_prototype)
237232
nlprob = NonlinearProblem(loss_function!, u_at_nodes, prob.p)
238233
sol_nlsolve = __solve(nlprob, alg.nlsolve; nlsolve_kwargs..., verbose, kwargs...)
239-
# u_at_nodes = sol_nlsolve.u
234+
u_at_nodes = sol_nlsolve.u::typeof(u0)
240235
end
241236

242237
single_shooting_prob = remake(prob; u0 = reshape(u_at_nodes[1:N], u0_size))

src/sparse_jacobians.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function __generate_sparse_jacobian_prototype(alg::MultipleShooting, ::TwoPointB
198198

199199
J_full[(L₁ + L₂ + 1):end, :] .= J_c.M
200200
J_full[1:L₁, 1:N] .= J_bc.M[1:L₁, 1:N]
201-
J_full[(L₁ + 1):(L₁ + L₂), (end - 2N + 1):(end - N)] .= J_bc.M[(L₁ + 1):(L₁ + L₂),
201+
J_full[(L₁ + 1):(L₁ + L₂), (end - N + 1):end] .= J_bc.M[(L₁ + 1):(L₁ + L₂),
202202
(N + 1):(2N)]
203203

204204
return J_full, J_c, J_bc

test/interpolation_test.jl renamed to test/mirk/interpolation_test.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function prob_bvp_linear_bc!(res, u, p, t)
1414
res[1] = u[1][1] - 1
1515
res[2] = u[end][1]
1616
end
17-
prob_bvp_linear_function = ODEFunction(prob_bvp_linear_f!, analytic = prob_bvp_linear_analytic)
17+
prob_bvp_linear_function = ODEFunction(prob_bvp_linear_f!,
18+
analytic = prob_bvp_linear_analytic)
1819
prob_bvp_linear_tspan = (0.0, 1.0)
1920
prob_bvp_linear = BVProblem(prob_bvp_linear_function, prob_bvp_linear_bc!,
2021
[1.0, 0.0], prob_bvp_linear_tspan, λ)
@@ -28,6 +29,6 @@ end
2829
@testset "Interpolation" begin
2930
@testset "MIRK$order" for order in (2, 3, 4, 5, 6)
3031
@time sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001)
31-
@test sol(0.001) [0.998687464, -1.312035941] atol=testTol
32+
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
3233
end
3334
end

test/runtests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL"))
1111
@time @safetestset "Ray Tracing BVP" begin
1212
include("shooting/ray_tracing.jl")
1313
end
14-
@time @safetestset "Orbital" begin
15-
include("shooting/orbital.jl")
14+
@static if VERSION v"1.10.0-beta2"
15+
# Orbital Tests take extremely long to compile on Julia 1.9
16+
@time @safetestset "Orbital" begin
17+
include("shooting/orbital.jl")
18+
end
1619
end
1720
end
1821
end
@@ -28,6 +31,9 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL"))
2831
@time @safetestset "Vector of Vector" begin
2932
include("mirk/vectorofvector_initials.jl")
3033
end
34+
@time @safetestset "Interpolation Tests" begin
35+
include("mirk/interpolation_test.jl")
36+
end
3137
end
3238
end
3339

@@ -46,10 +52,4 @@ const GROUP = uppercase(get(ENV, "GROUP", "ALL"))
4652
end
4753
end
4854
end
49-
50-
@time @testset "Interpolation Tests" begin
51-
@time @safetestset "MIRK Interpolation Test" begin
52-
include("interpolation_test.jl")
53-
end
54-
end
5555
end

0 commit comments

Comments
 (0)