Skip to content

Commit 154ee9b

Browse files
committed
Temporarily use Cholesky
1 parent d8fdd37 commit 154ee9b

File tree

4 files changed

+20
-20
lines changed

4 files changed

+20
-20
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ ConcreteStructs = "0.2"
4343
DiffEqBase = "6.135"
4444
ForwardDiff = "0.10"
4545
LinearAlgebra = "1.6"
46+
LinearSolve = "2"
4647
NonlinearSolve = "2.5"
4748
ODEInterface = "0.5"
4849
OrdinaryDiffEq = "6"
@@ -62,6 +63,7 @@ julia = "1.9"
6263
[extras]
6364
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6465
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
66+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
6567
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
6668
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
6769
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -70,4 +72,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7072
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7173

7274
[targets]
73-
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua"]
75+
test = ["StaticArrays", "Random", "DiffEqDevTools", "OrdinaryDiffEq", "Test", "SafeTestsets", "ODEInterface", "Aqua", "LinearSolve"]

src/solve/mirk.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,11 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
305305
jac = if iip
306306
(J, u, p) -> __mirk_mpoint_jacobian!(J, u, p, jac_alg.bc_diffmode,
307307
jac_alg.nonbc_diffmode, cache_bc, cache_collocation, loss_bcₚ,
308-
loss_collocationₚ, resid_bc, resid_collocation, cache.M)
308+
loss_collocationₚ, resid_bc, resid_collocation, cache.M, L)
309309
else
310310
(u, p) -> __mirk_mpoint_jacobian(u, p, jac_prototype, jac_alg.bc_diffmode,
311311
jac_alg.nonbc_diffmode, cache_bc, cache_collocation, loss_bcₚ,
312-
loss_collocationₚ, cache.M)
312+
loss_collocationₚ, cache.M, L)
313313
end
314314

315315
nlf = NonlinearFunction{iip}(loss; resid_prototype = vcat(resid_bc, resid_collocation),
@@ -319,17 +319,17 @@ end
319319

320320
function __mirk_mpoint_jacobian!(J, x, p, bc_diffmode, nonbc_diffmode, bc_diffcache,
321321
nonbc_diffcache, loss_bc::BC, loss_collocation::C, resid_bc, resid_collocation,
322-
M::Int) where {BC, C}
323-
sparse_jacobian!(@view(J[1:M, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x)
324-
sparse_jacobian!(@view(J[(M + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
322+
M::Int, L::Int) where {BC, C}
323+
sparse_jacobian!(@view(J[1:L, :]), bc_diffmode, bc_diffcache, loss_bc, resid_bc, x)
324+
sparse_jacobian!(@view(J[(L + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
325325
loss_collocation, resid_collocation, x)
326326
return nothing
327327
end
328328

329329
function __mirk_mpoint_jacobian(x, p, J, bc_diffmode, nonbc_diffmode, bc_diffcache,
330-
nonbc_diffcache, loss_bc::BC, loss_collocation::C, M::Int) where {BC, C}
331-
sparse_jacobian!(@view(J[1:M, :]), bc_diffmode, bc_diffcache, loss_bc, x)
332-
sparse_jacobian!(@view(J[(M + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
330+
nonbc_diffcache, loss_bc::BC, loss_collocation::C, M::Int, L::Int) where {BC, C}
331+
sparse_jacobian!(@view(J[1:L, :]), bc_diffmode, bc_diffcache, loss_bc, x)
332+
sparse_jacobian!(@view(J[(L + 1):end, :]), nonbc_diffmode, nonbc_diffcache,
333333
loss_collocation, x)
334334
return J
335335
end
@@ -341,9 +341,9 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo
341341

342342
lossₚ = iip ? ((du, u) -> loss(du, u, cache.p)) : (u -> loss(u, cache.p))
343343

344-
resid = vcat(cache.bcresid_prototype[1:prod(cache.resid_size[1])],
344+
resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
345345
similar(y, cache.M * (N - 1)),
346-
cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])
346+
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]))
347347
L = length(cache.bcresid_prototype)
348348

349349
sd = if jac_alg.diffmode isa AbstractSparseADType

src/solve/multiple_shooting.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_
133133

134134
jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
135135
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
136-
ode_fn, bc_fn, alg, N)
136+
ode_fn, bc_fn, alg, N, M)
137137

138138
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
139139
jac_prototype)
@@ -186,9 +186,9 @@ end
186186

187187
function __multiple_shooting_mpoint_jacobian!(J, us, p, resid_bc, resid_nodes,
188188
ode_jac_cache, bc_jac_cache, ode_fn::F1, bc_fn::F2, alg::MultipleShooting,
189-
N::Int) where {F1, F2}
190-
J_bc = @view(J[1:N, :])
191-
J_c = @view(J[(N + 1):end, :])
189+
N::Int, M::Int) where {F1, F2}
190+
J_bc = @view(J[1:M, :])
191+
J_c = @view(J[(M + 1):end, :])
192192

193193
sparse_jacobian!(J_c, alg.jac_alg.nonbc_diffmode, ode_jac_cache, ode_fn,
194194
resid_nodes.du, us)

test/shooting/shooting_tests.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test
1+
using BoundaryValueDiffEq, LinearAlgebra, LinearSolve, OrdinaryDiffEq, Test
22

33
@testset "Basic Shooting Tests" begin
44
SOLVERS = [Shooting(Tsit5()), MultipleShooting(10, Tsit5())]
@@ -83,12 +83,10 @@ end
8383
@testset "Overconstrained BVP" begin
8484
SOLVERS = [
8585
Shooting(Tsit5();
86-
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
87-
α_geodesic = 0.9, b_uphill = 2.0)),
86+
nlsolve = LevenbergMarquardt(; linsolve = CholeskyFactorization())),
8887
Shooting(Tsit5(); nlsolve = GaussNewton()),
8988
MultipleShooting(10, Tsit5();
90-
nlsolve = LevenbergMarquardt(; damping_initial = 1e-6,
91-
α_geodesic = 0.9, b_uphill = 2.0)),
89+
nlsolve = LevenbergMarquardt(; linsolve = CholeskyFactorization())),
9290
MultipleShooting(10, Tsit5(); nlsolve = GaussNewton())]
9391

9492
# OOP MP-BVP

0 commit comments

Comments
 (0)