Skip to content

Commit 1568352

Browse files
committed
Added ThreadedColumnizedSparseMatrix
1 parent 48c54f4 commit 1568352

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

src/ThreadedSparseArrays.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ import LinearAlgebra: mul!
55
using SparseArrays
66
import SparseArrays: AdjOrTransStridedOrTriangularMatrix, getcolptr
77

8+
# * ThreadedSparseMatrixCSC
9+
10+
"""
11+
ThreadedSparseMatrixCSC(A)
12+
13+
Thin container around `A::SparseMatrixCSC` that will enable certain
14+
threaded multiplications of `A` with dense matrices.
15+
"""
816
struct ThreadedSparseMatrixCSC{Tv,Ti,At} <: AbstractSparseMatrix{Tv,Ti}
917
A::At
1018
ThreadedSparseMatrixCSC(A::At) where {Tv,Ti,At<:AbstractSparseMatrix{Tv,Ti}} =
@@ -87,6 +95,42 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Thr
8795
C
8896
end
8997

90-
export ThreadedSparseMatrixCSC
98+
# * ThreadedColumnizedSparseMatrix
99+
"""
    ThreadedColumnizedSparseMatrix(Tv, Ti, columns, m, n)
    ThreadedColumnizedSparseMatrix(columns, m, n)

Sparse matrix of size `m×n` where the `columns` are stored separately,
enabling threaded multiplication. Seems faster than
[`ThreadedSparseMatrixCSC`](@ref) for some use cases.

# Arguments
- `Tv`, `Ti`: value and index types, used as the parameters of the
  `AbstractSparseMatrix{Tv,Ti}` supertype.
- `columns`: indexable collection; `columns[j]` is used as column `j`.
- `m`, `n`: number of rows and columns. `n` must equal `length(columns)`.

The three-argument form takes `Tv` and `Ti` from the element type of `columns`,
which must be a vector of `AbstractSparseVector{Tv,Ti}`.

# Throws
- `DimensionMismatch` if `length(columns) != n`.
"""
struct ThreadedColumnizedSparseMatrix{Tv,Ti,Columns} <: AbstractSparseMatrix{Tv,Ti}
    columns::Columns
    m::Int
    n::Int

    function ThreadedColumnizedSparseMatrix(
        ::Type{Tv}, ::Type{Ti}, columns::Columns, m, n
    ) where {Tv,Ti,Columns}
        # Every column index in 1:n is later looked up in `columns`, so an
        # inconsistent `n` is rejected here instead of failing on first use.
        ncols = length(columns)
        ncols == n || throw(DimensionMismatch(
            "number of columns ($ncols) does not match n = $n"))
        return new{Tv,Ti,Columns}(columns, m, n)
    end
end

# Convenience method: read the value/index types off the column element type.
function ThreadedColumnizedSparseMatrix(
    columns::AbstractVector{<:AbstractSparseVector{Tv,Ti}}, m, n
) where {Tv,Ti}
    return ThreadedColumnizedSparseMatrix(Tv, Ti, columns, m, n)
end
114+
"""
    ThreadedColumnizedSparseMatrix(A::AbstractSparseMatrix)

Build a `ThreadedColumnizedSparseMatrix` with the same size as `A` by
extracting each column `A[:, j]` and storing the columns in a vector.
"""
function ThreadedColumnizedSparseMatrix(A::AbstractSparseMatrix{Tv,Ti}) where {Tv,Ti}
    m, n = size(A)
    # A comprehension determines its own element type, so a matrix with zero
    # columns no longer needs `A[:, 1]` (which would throw a BoundsError).
    columns = [A[:, j] for j in 1:n]
    return ThreadedColumnizedSparseMatrix(Tv, Ti, columns, m, n)
end
121+
# Size is the stored `(m, n)` pair.
Base.size(A::ThreadedColumnizedSparseMatrix) = (A.m, A.n)
# Dimensions past the second have length 1, following the `AbstractArray`
# convention, instead of raising a BoundsError on the size tuple.
Base.size(A::ThreadedColumnizedSparseMatrix, i) = i <= 2 ? size(A)[i] : 1
# Element lookup: index `i` within the stored column `j`.
Base.getindex(A::ThreadedColumnizedSparseMatrix, i, j) = A.columns[j][i]
125+
"""
    mul!(A, B, C::ThreadedColumnizedSparseMatrix, α=true, β=false)

Overwrite the destination `A` with `α*B*C + β*A` and return `A`.

Each column `j` of `A` is computed by a separate five-argument `mul!` call on
`view(A, :, j)`, `B` and the stored column `C.columns[j]`; the loop over `j`
is run with `Threads.@threads`.

# Throws
- `DimensionMismatch` if `size(B, 2)` differs from the row count of `C`, or if
  `size(A)` is not `(size(B, 1), n)` with `n` the column count of `C`.
"""
function LinearAlgebra.mul!(A::AbstractMatrix, B::AbstractMatrix, C::ThreadedColumnizedSparseMatrix,
                            α::Number=true, β::Number=false)
    # Validate shapes before spawning threads: a mismatch would otherwise
    # surface as a failure inside a task, or (when `A` has extra columns)
    # leave part of `A` silently untouched.
    size(B, 2) == C.m || throw(DimensionMismatch(
        "second dimension of B, $(size(B, 2)), does not match first dimension of C, $(C.m)"))
    size(A) == (size(B, 1), C.n) || throw(DimensionMismatch(
        "destination has size $(size(A)), expected $((size(B, 1), C.n))"))

    # Columns of `A` are disjoint, so iterations write to independent memory.
    Threads.@threads for j in 1:C.n
        LinearAlgebra.mul!(view(A, :, j), B, C.columns[j], α, β)
    end
    return A
end
133+
134+
export ThreadedSparseMatrixCSC, ThreadedColumnizedSparseMatrix
91135

92136
end # module

test/runtests.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ using Test
99
T = ComplexF64
1010

1111
C = sprand(T, N, n, 0.05)
12-
Ct = ThreadedSparseMatrixCSC(C)
12+
@testset "$(Mat)" for Mat in [ThreadedSparseMatrixCSC, ThreadedColumnizedSparseMatrix]
13+
Ct = Mat(C)
1314

14-
eye = Matrix(one(T)*I, N, N)
15-
out = zeros(T, N, n)
16-
LinearAlgebra.mul!(out, eye, Ct)
17-
ref = eye*C
18-
@test norm(ref-out) == 0
15+
eye = Matrix(one(T)*I, N, N)
16+
out = zeros(T, N, n)
17+
LinearAlgebra.mul!(out, eye, Ct)
18+
ref = eye*C
19+
@test norm(ref-out) == 0
20+
end
1921
end

0 commit comments

Comments
 (0)