Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit 1f1c066

Browse files
committed
Specialize sparse for various operators
1 parent cb7d20d commit 1f1c066

File tree

5 files changed

+80
-7
lines changed

5 files changed

+80
-7
lines changed

src/composite_operators.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
# operator types are lazy and maintain the structure used to build them.
44

55

6+
# Define a helper function `sparse1` that handles
7+
# `DiffEqArrayOperator` and `DiffEqScaledOperator`.
8+
# We should define `sparse` for these types in `SciMLBase` instead,
9+
# but that package doesn't know anything about sparse arrays yet, so
10+
# we'll introduce a temporary work-around here.
11+
sparse1(A) = sparse(A)
12+
sparse1(A::DiffEqArrayOperator) = sparse1(A.A)
13+
sparse1(A::DiffEqScaledOperator) = A.coeff * sparse1(A.op)
14+
15+
616
# Linear Combination
717
struct DiffEqOperatorCombination{T,O<:Tuple{Vararg{AbstractDiffEqLinearOperator{T}}},
818
C<:AbstractVector{T}} <: AbstractDiffEqCompositeOperator{T}
@@ -13,7 +23,7 @@ struct DiffEqOperatorCombination{T,O<:Tuple{Vararg{AbstractDiffEqLinearOperator{
1323
for i in 2:length(ops)
1424
@assert size(ops[i]) == size(ops[1]) "Operators must be of the same size to be combined! Mismatch between $(ops[i]) and $(ops[i-1]), which are operators $i and $(i-1) respectively"
1525
end
16-
if cache == nothing
26+
if cache === nothing
1727
cache = zeros(T, size(ops[1], 1))
1828
end
1929
new{T,typeof(ops),typeof(cache)}(ops, cache)
@@ -36,6 +46,7 @@ getops(L::DiffEqOperatorCombination) = L.ops
3646
Matrix(L::DiffEqOperatorCombination) = sum(Matrix, L.ops)
3747
convert(::Type{AbstractMatrix}, L::DiffEqOperatorCombination) =
3848
sum(op -> convert(AbstractMatrix, op), L.ops)
49+
SparseArrays.sparse(L::DiffEqOperatorCombination) = sum(sparse1, L.ops)
3950

4051
size(L::DiffEqOperatorCombination, args...) = size(L.ops[1], args...)
4152
getindex(L::DiffEqOperatorCombination, i::Int) = sum(op -> op[i], L.ops)
@@ -64,7 +75,7 @@ struct DiffEqOperatorComposition{T,O<:Tuple{Vararg{AbstractDiffEqLinearOperator{
6475
@assert size(ops[i-1], 1) == size(ops[i], 2) "Operations do not have compatible sizes! Mismatch between $(ops[i]) and $(ops[i-1]), which are operators $i and $(i-1) respectively."
6576
end
6677

67-
if caches == nothing
78+
if caches === nothing
6879
# Construct a list of caches to be used by mul! and ldiv!
6980
caches = []
7081
for op in ops[1:end-1]
@@ -89,6 +100,7 @@ getops(L::DiffEqOperatorComposition) = L.ops
89100
Matrix(L::DiffEqOperatorComposition) = prod(Matrix, reverse(L.ops))
90101
convert(::Type{AbstractMatrix}, L::DiffEqOperatorComposition) =
91102
prod(op -> convert(AbstractMatrix, op), reverse(L.ops))
103+
SparseArrays.sparse(L::DiffEqOperatorComposition) = prod(sparse1, reverse(L.ops))
92104

93105
size(L::DiffEqOperatorComposition) = (size(L.ops[end], 1), size(L.ops[1], 2))
94106
size(L::DiffEqOperatorComposition, m::Integer) = size(L)[m]

src/derivative_operators/concretization.jl

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,68 @@ end
4545
LinearAlgebra.Array(A::DerivativeOperator{T}, N::Int=A.len) where T =
4646
copyto!(zeros(T, N, N+2), A, N)
4747

48-
SparseArrays.SparseMatrixCSC(A::DerivativeOperator{T}, N::Int=A.len) where T =
49-
copyto!(spzeros(T, N, N+2), A, N)
48+
function SparseArrays.SparseMatrixCSC(A::DerivativeOperator{T}, N::Int=A.len) where T
49+
bl = A.boundary_point_count
50+
stencil_length = A.stencil_length
51+
stencil_pivot = use_winding(A) ? (1 + stencil_length%2) : div(stencil_length,2)
52+
bstl = A.boundary_stencil_length
53+
54+
coeff = A.coefficients
55+
get_coeff = if coeff isa AbstractVector
56+
i -> coeff[i]
57+
elseif coeff isa Number
58+
i -> coeff
59+
else
60+
i -> true
61+
end
62+
63+
Is = Int[]
64+
Js = Int[]
65+
Vs = T[]
66+
67+
nvalues = 2*bl * bstl + (N - 2*bl) * stencil_length
68+
sizehint!(Is, nvalues)
69+
sizehint!(Js, nvalues)
70+
sizehint!(Vs, nvalues)
71+
72+
for i in 1:bl
73+
cur_coeff = get_coeff(i)
74+
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(A.low_boundary_coefs[i]) : A.low_boundary_coefs[i]
75+
append!(Is, ((i for j in 1:bstl)...))
76+
append!(Js, 1:bstl)
77+
append!(Vs, cur_coeff * cur_stencil)
78+
end
79+
80+
for i in bl+1:N-bl
81+
cur_coeff = get_coeff(i)
82+
stencil = eltype(A.stencil_coefs) <: AbstractVector ? A.stencil_coefs[i-bl] : A.stencil_coefs
83+
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(stencil) : stencil
84+
append!(Is, ((i for j in 1:stencil_length)...))
85+
append!(Js, i+1-stencil_pivot:i-stencil_pivot+stencil_length)
86+
append!(Vs, cur_coeff * cur_stencil)
87+
end
88+
89+
for i in N-bl+1:N
90+
cur_coeff = get_coeff(i)
91+
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(A.high_boundary_coefs[i-N+bl]) : A.high_boundary_coefs[i-N+bl]
92+
append!(Is, ((i for j in N-bstl+3:N+2)...))
93+
append!(Js, N-bstl+3:N+2)
94+
append!(Vs, cur_coeff * cur_stencil)
95+
end
96+
97+
# ensure efficient allocation
98+
@assert length(Is) == nvalues
99+
@assert length(Js) == nvalues
100+
@assert length(Vs) == nvalues
101+
102+
return sparse(Is, Js, Vs, N, N+2)
103+
end
50104

51105
SparseArrays.sparse(A::DerivativeOperator{T}, N::Int=A.len) where T = SparseMatrixCSC(A,N)
52106

107+
Base.copyto!(L::AbstractSparseArray{T}, A::DerivativeOperator{T}, N::Int) where T =
108+
copyto!(L, sparse(A))
109+
53110
function BandedMatrices.BandedMatrix(A::DerivativeOperator{T}, N::Int=A.len) where T
54111
stencil_length = A.stencil_length
55112
bstl = A.boundary_stencil_length

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
183183
opsA = DerivativeOperator[]
184184
opsB = DerivativeOperator[]
185185
for L in A.ops
186-
if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) == false && L.dx isa Number
186+
if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) === false && L.dx isa Number
187187
push!(opsA, L)
188188
else
189189
push!(opsB,L)
@@ -416,7 +416,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,3}, A::AbstractDiffEqComposi
416416
opsA = DerivativeOperator[]
417417
opsB = DerivativeOperator[]
418418
for L in A.ops
419-
if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) == false && L.dx isa Number
419+
if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) === false && L.dx isa Number
420420
push!(opsA, L)
421421
else
422422
push!(opsB,L)

test/DerivativeOperators/composite_operators_interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, LinearAlgebra, Random, DiffEqOperators
1+
using Test, LinearAlgebra, Random, SparseArrays, DiffEqOperators
22
using DiffEqBase
33
using DiffEqBase: isconstant
44
using DiffEqOperators: DiffEqScaledOperator, DiffEqOperatorCombination, DiffEqOperatorComposition
@@ -22,6 +22,8 @@ using DiffEqOperators: DiffEqScaledOperator, DiffEqOperatorCombination, DiffEqOp
2222
@test opnorm(L) opnorm(Lfull)
2323
@test size(L) == size(Lfull)
2424
@test L[1,2] Lfull[1,2]
25+
Lsparse = sparse(L)
26+
@test Lsparse == Lfull
2527
u = [1.0, 2.0]; du = zeros(2)
2628
@test L * u Lfull * u
2729
mul!(du, L, u); @test du Lfull * u

test/DerivativeOperators/generic_operator_validation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ for dor in 1:4, aor in 2:2:6
1414
Dr = CenteredDifference(dor,aor,dx[1],length(x)-2)
1515
Dir = CenteredDifference(dor,aor,dx,length(x)-2)
1616

17+
@test sparse(Dr)==Array(Dr)
18+
1719
@test sparse(Dr)sparse(Dir)
1820
@test Array(Dr)Array(Dir)
1921

0 commit comments

Comments
 (0)