Skip to content

Commit e9c8a4c

Browse files
Merge pull request #124 from avik-pal/ap/banded
Setup for using Banded Structures in BVP
2 parents 4b210c6 + 89bcde3 commit e9c8a4c

File tree

11 files changed

+238
-308
lines changed

11 files changed

+238
-308
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "BoundaryValueDiffEq"
22
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
3-
version = "5.1.0"
3+
version = "5.2.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
910
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1011
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1112
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -18,7 +19,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1819
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
21-
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2222
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2323
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2424

@@ -32,6 +32,7 @@ BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"
3232
ADTypes = "0.2"
3333
Adapt = "3"
3434
ArrayInterface = "7"
35+
BandedMatrices = "1"
3536
ConcreteStructs = "0.2"
3637
DiffEqBase = "6.94.2"
3738
ForwardDiff = "0.10"

src/BoundaryValueDiffEq.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
module BoundaryValueDiffEq
22

3-
using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays, SciMLBase,
4-
Static, RecursiveArrayTools, ForwardDiff
3+
using Adapt, BandedMatrices, ForwardDiff, LinearAlgebra, PreallocationTools,
4+
RecursiveArrayTools, Reexport, Setfield, SparseArrays
55
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase
66

77
import ADTypes: AbstractADType
8-
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix
8+
import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scalar_indexing
99
import ConcreteStructs: @concrete
1010
import DiffEqBase: solve
1111
import ForwardDiff: pickchunksize
1212
import RecursiveArrayTools: ArrayPartition, DiffEqArray
13-
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve
14-
import RecursiveArrayTools: ArrayPartition
13+
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
1514
import SparseDiffTools: AbstractSparseADType
1615
import TruncatedStacktraces: @truncate_stacktrace
1716
import UnPack: @unpack

src/solve/mirk.jl

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
fᵢ₂_cache
2626
defect
2727
new_stages
28+
resid_size
2829
kwargs
2930
end
3031

@@ -64,8 +65,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
6465
bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X)
6566

6667
residual = if iip
67-
vcat([__alloc_diffcache(bcresid_prototype)],
68-
__alloc_diffcache.(copy.(@view(y₀[2:end]))))
68+
if prob.problem_type isa TwoPointBVProblem
69+
vcat([__alloc_diffcache(__vec(bcresid_prototype))],
70+
__alloc_diffcache.(copy.(@view(y₀[2:end]))))
71+
else
72+
vcat([__alloc_diffcache(bcresid_prototype)],
73+
__alloc_diffcache.(copy.(@view(y₀[2:end]))))
74+
end
6975
else
7076
nothing
7177
end
@@ -74,6 +80,7 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
7480
new_stages = [similar(X, ifelse(adaptive, M, 0)) for _ in 1:n]
7581

7682
# Transform the functions to handle non-vector inputs
83+
bcresid_prototype = __vec(bcresid_prototype)
7784
f, bc = if X isa AbstractVector
7885
prob.f, prob.f.bc
7986
elseif iip
@@ -92,7 +99,6 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
9299
end
93100
(__vecbc_a!, __vecbc_b!)
94101
end
95-
bcresid_prototype = vec(bcresid_prototype)
96102
vecf!, vecbc!
97103
else
98104
vecf(u, p, t) = vec(prob.f(reshape(u, size(X)), p, t))
@@ -103,14 +109,13 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
103109
__vecbc_b(ub, p) = vec(prob.f.bc[2](reshape(ub, size(X)), p))
104110
(__vecbc_a, __vecbc_b)
105111
end
106-
bcresid_prototype = vec(bcresid_prototype)
107112
vecf, vecbc
108113
end
109114

110115
return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob,
111116
prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
112117
k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages,
113-
(; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...))
118+
resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...))
114119
end
115120

116121
"""
@@ -224,13 +229,21 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
224229
end
225230

226231
loss = if iip
227-
function loss_internal!(resid::AbstractVector, u::AbstractVector, p = cache.p)
232+
@views function loss_internal!(resid::AbstractVector,
233+
u::AbstractVector,
234+
p = cache.p)
228235
y_ = recursive_unflatten!(cache.y, u)
229236
resids = [get_tmp(r, u) for r in cache.residual]
230-
eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p, cache.mesh)
237+
resid_bc = if cache.problem_type isa TwoPointBVProblem
238+
(resids[1][1:prod(cache.resid_size[1])],
239+
resids[1][(prod(cache.resid_size[1]) + 1):end])
240+
else
241+
resids[1]
242+
end
243+
eval_bc_residual!(resid_bc, cache.problem_type, cache.bc, y_, p, cache.mesh)
231244
Φ!(resids[2:end], cache, y_, u, p)
232245
if cache.problem_type isa TwoPointBVProblem
233-
recursive_flatten_twopoint!(resid, resids)
246+
recursive_flatten_twopoint!(resid, resids, cache.resid_size)
234247
else
235248
recursive_flatten!(resid, resids)
236249
end
@@ -242,7 +255,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {
242255
resid_bc = eval_bc_residual(cache.problem_type, cache.bc, y_, p, cache.mesh)
243256
resid_co = Φ(cache, y_, u, p)
244257
if cache.problem_type isa TwoPointBVProblem
245-
return vcat(resid_bc.x[1], mapreduce(vec, vcat, resid_co), resid_bc.x[2])
258+
return vcat(resid_bc[1], mapreduce(vec, vcat, resid_co), resid_bc[2])
246259
else
247260
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
248261
end
@@ -268,7 +281,7 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
268281

269282
sd_collocation = if jac_alg.nonbc_diffmode isa AbstractSparseADType
270283
PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache,
271-
cache.problem_type, y, cache.M, N))
284+
cache.problem_type, y, y, cache.M, N))
272285
else
273286
NoSparsityDetection()
274287
end
@@ -299,19 +312,20 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocati
299312
return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p)
300313
end
301314

302-
function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss,
303-
::TwoPointBVProblem) where {iip}
315+
function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation,
316+
loss, ::TwoPointBVProblem) where {iip}
304317
@unpack nlsolve, jac_alg = cache.alg
305318
N = length(cache.mesh)
306319

307-
resid = ArrayPartition(cache.bcresid_prototype, similar(y, cache.M * (N - 1)))
320+
resid = vcat(cache.bcresid_prototype[1:prod(cache.resid_size[1])],
321+
similar(y, cache.M * (N - 1)),
322+
cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])
308323

309-
# TODO: We can splitup the computation here as well similar to the Multiple Shooting
310-
# TODO: code. That way for the BC part the actual jacobian computation is even cheaper
311-
# TODO: Remember to not reorder if we end up using that implementation
312324
sd = if jac_alg.diffmode isa AbstractSparseADType
313325
PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache,
314-
cache.problem_type, resid.x[1], cache.M, N))
326+
cache.problem_type, @view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]),
327+
@view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]), cache.M,
328+
N))
315329
else
316330
NoSparsityDetection()
317331
end

0 commit comments

Comments
 (0)