|
1 | 1 | module ThreadedSparseArrays
|
2 | 2 |
|
3 |
| -greet() = print("Hello World!") |
| 3 | +using LinearAlgebra |
| 4 | +import LinearAlgebra: mul! |
| 5 | +using SparseArrays |
| 6 | +import SparseArrays: AdjOrTransStridedOrTriangularMatrix, getcolptr |
| 7 | + |
| 8 | +struct ThreadedSparseMatrixCSC{Tv,Ti,At} <: AbstractSparseMatrix{Tv,Ti} |
| 9 | + A::At |
| 10 | + ThreadedSparseMatrixCSC(A::At) where {Tv,Ti,At<:AbstractSparseMatrix{Tv,Ti}} = |
| 11 | + new{Tv,Ti,At}(A) |
| 12 | +end |
| 13 | + |
| 14 | +Base.size(A::ThreadedSparseMatrixCSC, args...) = size(A.A, args...) |
| 15 | +Base.eltype(A::ThreadedSparseMatrixCSC) = eltype(A.A) |
| 16 | +Base.getindex(A::ThreadedSparseMatrixCSC, args...) = getindex(A.A, args...) |
| 17 | + |
| 18 | +# Need to override printing |
| 19 | +# Need to forward findnz, etc |
| 20 | + |
| 21 | +for f in [:rowvals, :nonzeros, :getcolptr] |
| 22 | + @eval SparseArrays.$(f)(A::ThreadedSparseMatrixCSC) = SparseArrays.$(f)(A.A) |
| 23 | +end |
| 24 | + |
| 25 | +function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number) |
| 26 | + A = adjA.parent |
| 27 | + size(A, 2) == size(C, 1) || throw(DimensionMismatch()) |
| 28 | + size(A, 1) == size(B, 1) || throw(DimensionMismatch()) |
| 29 | + size(B, 2) == size(C, 2) || throw(DimensionMismatch()) |
| 30 | + colptrA = getcolptr(A) |
| 31 | + nzv = nonzeros(A) |
| 32 | + rv = rowvals(A) |
| 33 | + if β != 1 |
| 34 | + β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) |
| 35 | + end |
| 36 | + for k = 1:size(C, 2) |
| 37 | + Threads.@threads for col = 1:size(A, 2) |
| 38 | + @inbounds begin |
| 39 | + tmp = zero(eltype(C)) |
| 40 | + for j = colptrA[col]:(colptrA[col+1] - 1) |
| 41 | + tmp += adjoint(nzv[j])*B[rv[j],k] |
| 42 | + end |
| 43 | + C[col,k] += α*tmp |
| 44 | + end |
| 45 | + end |
| 46 | + end |
| 47 | + C |
| 48 | +end |
| 49 | + |
| 50 | +function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number) |
| 51 | + A = transA.parent |
| 52 | + size(A, 2) == size(C, 1) || throw(DimensionMismatch()) |
| 53 | + size(A, 1) == size(B, 1) || throw(DimensionMismatch()) |
| 54 | + size(B, 2) == size(C, 2) || throw(DimensionMismatch()) |
| 55 | + nzv = nonzeros(A) |
| 56 | + rv = rowvals(A) |
| 57 | + if β != 1 |
| 58 | + β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) |
| 59 | + end |
| 60 | + Threads.@threads for k = 1:size(C, 2) |
| 61 | + @inbounds for col = 1:size(A, 2) |
| 62 | + tmp = zero(eltype(C)) |
| 63 | + for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1) |
| 64 | + tmp += transpose(nzv[j])*B[rv[j],k] |
| 65 | + end |
| 66 | + C[col,k] += tmp * α |
| 67 | + end |
| 68 | + end |
| 69 | + C |
| 70 | +end |
| 71 | + |
| 72 | +function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::ThreadedSparseMatrixCSC, α::Number, β::Number) |
| 73 | + mX, nX = size(X) |
| 74 | + nX == size(A, 1) || throw(DimensionMismatch()) |
| 75 | + mX == size(C, 1) || throw(DimensionMismatch()) |
| 76 | + size(A, 2) == size(C, 2) || throw(DimensionMismatch()) |
| 77 | + rv = rowvals(A) |
| 78 | + nzv = nonzeros(A) |
| 79 | + if β != 1 |
| 80 | + β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) |
| 81 | + end |
| 82 | + Threads.@threads for col = 1:size(A, 2) |
| 83 | + @inbounds for multivec_row=1:mX, k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1) |
| 84 | + C[multivec_row, col] += α * X[multivec_row, rv[k]] * nzv[k] # perhaps suboptimal position of α? |
| 85 | + end |
| 86 | + end |
| 87 | + C |
| 88 | +end |
| 89 | + |
| 90 | +export ThreadedSparseMatrixCSC |
4 | 91 |
|
5 | 92 | end # module
|
0 commit comments