|
1 | 1 | module ThreadedSparseArrays
|
2 | 2 |
|
| 3 | + |
| 4 | +export ThreadedSparseMatrixCSC |
| 5 | + |
| 6 | + |
3 | 7 | using LinearAlgebra
|
4 | 8 | import LinearAlgebra: mul!
|
5 | 9 | using SparseArrays
|
6 |
| -import SparseArrays: getcolptr |
| 10 | +import SparseArrays: getcolptr, AbstractSparseMatrixCSC |
7 | 11 | const AdjOrTransDenseMatrix = if VERSION < v"1.6.0-rc2"
|
8 | 12 | SparseArrays.AdjOrTransStridedOrTriangularMatrix
|
9 | 13 | else
|
@@ -36,31 +40,18 @@ Base.iterate(it::RangeIterator, i::Int=1) = i>it.k ? nothing : (endpos(it,i-1)+1
|
36 | 40 | Thin container around `A::SparseMatrixCSC` that will enable certain
|
37 | 41 | threaded multiplications of `A` with dense matrices.
|
38 | 42 | """
|
39 |
| -struct ThreadedSparseMatrixCSC{Tv,Ti,At} <: AbstractSparseMatrix{Tv,Ti} |
| 43 | +struct ThreadedSparseMatrixCSC{Tv,Ti,At} <: AbstractSparseMatrixCSC{Tv,Ti} |
40 | 44 | A::At
|
41 |
| - ThreadedSparseMatrixCSC(A::At) where {Tv,Ti,At<:AbstractSparseMatrix{Tv,Ti}} = |
| 45 | + ThreadedSparseMatrixCSC(A::At) where {Tv,Ti,At<:AbstractSparseMatrixCSC{Tv,Ti}} = |
42 | 46 | new{Tv,Ti,At}(A)
|
43 | 47 | end
|
44 | 48 |
|
45 | 49 | Base.size(A::ThreadedSparseMatrixCSC, args...) = size(A.A, args...)
|
46 |
| -Base.eltype(A::ThreadedSparseMatrixCSC) = eltype(A.A) |
47 |
| -Base.getindex(A::ThreadedSparseMatrixCSC, args...) = getindex(A.A, args...) |
48 |
| - |
49 |
| -# Need to override printing |
50 |
| -# Need to forward findnz, etc |
51 | 50 |
|
52 | 51 | for f in [:rowvals, :nonzeros, :getcolptr]
|
53 | 52 | @eval SparseArrays.$(f)(A::ThreadedSparseMatrixCSC) = SparseArrays.$(f)(A.A)
|
54 | 53 | end
|
55 | 54 |
|
56 |
| -# For non-threaded implementations, fallback to sparse methods and not generic matmul. |
57 |
| -mul!(C::AbstractVector, A::ThreadedSparseMatrixCSC, B::AbstractVector, α::Number, β::Number) = mul!(C, A.A, B, α, β) |
58 |
| -mul!(C::AbstractMatrix, A::ThreadedSparseMatrixCSC, B::AbstractMatrix, α::Number, β::Number) = mul!(C, A.A, B, α, β) |
59 |
| -mul!(C::AbstractVector, A::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractVector, α::Number, β::Number) = mul!(C, adjoint(A.parent.A), B, α, β) |
60 |
| -mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractMatrix, α::Number, β::Number) = mul!(C, adjoint(A.parent.A), B, α, β) |
61 |
| -mul!(C::AbstractVector, A::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractVector, α::Number, β::Number) = mul!(C, transpose(A.parent.A), B, α, β) |
62 |
| -mul!(C::AbstractMatrix, A::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::AbstractMatrix, α::Number, β::Number) = mul!(C, transpose(A.parent.A), B, α, β) |
63 |
| - |
64 | 55 | function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVector,AdjOrTransDenseMatrix}, α::Number, β::Number)
|
65 | 56 | size(A, 2) == size(B, 1) || throw(DimensionMismatch())
|
66 | 57 | size(A, 1) == size(C, 1) || throw(DimensionMismatch())
|
@@ -201,42 +192,4 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::ThreadedSparseMat
|
201 | 192 | C
|
202 | 193 | end
|
203 | 194 |
|
204 |
| -# * ThreadedColumnizedSparseMatrix |
205 |
| - |
206 |
| -""" |
207 |
| - ThreadedColumnizedSparseMatrix(columns, m, n) |
208 |
| -
|
209 |
| -Sparse matrix of size `m×n` where the `columns` are stored separately, |
210 |
| -enabling threaded multiplication. Seems faster than |
211 |
| -[`ThreadedSparseMatrixCSC`](@ref) for some use cases. |
212 |
| -""" |
213 |
| -struct ThreadedColumnizedSparseMatrix{Tv,Ti,Columns} <: AbstractSparseMatrix{Tv,Ti} |
214 |
| - columns::Columns |
215 |
| - m::Int |
216 |
| - n::Int |
217 |
| - ThreadedColumnizedSparseMatrix(::Type{Tv}, ::Type{Ti}, columns::Columns, m, n) where {Tv,Ti,Columns} = |
218 |
| - new{Tv,Ti,Columns}(columns, m, n) |
219 |
| -end |
220 |
| - |
221 |
| -function ThreadedColumnizedSparseMatrix(A::AbstractSparseMatrix{Tv,Ti}) where {Tv,Ti} |
222 |
| - m,n = size(A) |
223 |
| - Column = typeof(A[:,1]) |
224 |
| - columns = Column[A[:,j] for j = 1:n] |
225 |
| - ThreadedColumnizedSparseMatrix(Tv, Ti, columns, m, n) |
226 |
| -end |
227 |
| - |
228 |
| -Base.size(A::ThreadedColumnizedSparseMatrix) = (A.m,A.n) |
229 |
| -Base.size(A::ThreadedColumnizedSparseMatrix,i) = size(A)[i] |
230 |
| -Base.getindex(A::ThreadedColumnizedSparseMatrix, i, j) = A.columns[j][i] |
231 |
| - |
232 |
| -function LinearAlgebra.mul!(A::AbstractMatrix, B::AbstractMatrix, C::ThreadedColumnizedSparseMatrix, |
233 |
| - α::Number=true, β::Number=false) |
234 |
| - Threads.@threads for j = 1:C.n |
235 |
| - LinearAlgebra.mul!(view(A, :, j), B, C.columns[j], α, β) |
236 |
| - end |
237 |
| - A |
238 |
| -end |
239 |
| - |
240 |
| -export ThreadedSparseMatrixCSC, ThreadedColumnizedSparseMatrix |
241 |
| - |
242 | 195 | end # module
|
0 commit comments