Skip to content

Commit 8ec68c8

Browse files
better mul! support and performance
1 parent 88eec7a commit 8ec68c8

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

src/ThreadedSparseArrays.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ for f in [:rowvals, :nonzeros, :getcolptr]
3030
@eval SparseArrays.$(f)(A::ThreadedSparseMatrixCSC) = SparseArrays.$(f)(A.A)
3131
end
3232

33+
function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
34+
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
35+
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
36+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
37+
nzv = nonzeros(A)
38+
rv = rowvals(A)
39+
if β != 1
40+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
41+
end
42+
Threads.@threads for k = 1:size(C, 2)
43+
@inbounds for col = 1:size(A, 2)
44+
αxj = B[col,k] * α
45+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
46+
C[rv[j], k] += nzv[j]*αxj
47+
end
48+
end
49+
end
50+
C
51+
end
52+
3353
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
3454
A = adjA.parent
3555
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
@@ -41,15 +61,13 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}
4161
if β != 1
4262
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
4363
end
44-
for k = 1:size(C, 2)
45-
Threads.@threads for col = 1:size(A, 2)
46-
@inbounds begin
47-
tmp = zero(eltype(C))
48-
for j = colptrA[col]:(colptrA[col+1] - 1)
49-
tmp += adjoint(nzv[j])*B[rv[j],k]
50-
end
51-
C[col,k] += α*tmp
64+
Threads.@threads for k = 1:size(C, 2)
65+
@inbounds for col = 1:size(A, 2)
66+
tmp = zero(eltype(C))
67+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
68+
tmp += adjoint(nzv[j])*B[rv[j],k]
5269
end
70+
C[col,k] += tmp * α
5371
end
5472
end
5573
C

0 commit comments

Comments
 (0)