Skip to content

Commit d3ee5a6

Browse files
Made threading composable
1 parent 8ec68c8 commit d3ee5a6

File tree

1 file changed

+54
-29
lines changed

1 file changed

+54
-29
lines changed

src/ThreadedSparseArrays.jl

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,29 @@ import LinearAlgebra: mul!
55
using SparseArrays
66
import SparseArrays: AdjOrTransStridedOrTriangularMatrix, getcolptr
77

8+
# * Threading utilities
9+
struct RangeIterator
10+
k::Int
11+
d::Int
12+
r::Int
13+
end
14+
15+
"""
16+
RangeIterator(n::Int,k::Int)
17+
18+
Returns an iterator splitting the range `1:n` into `min(k,n)` parts of (almost) equal size.
19+
"""
20+
RangeIterator(n::Int,k::Int) = RangeIterator(min(n,k),divrem(n,k)...)
21+
Base.length(it::RangeIterator) = it.k
22+
#Base.iterate(it::RangeIterator, i::Int=1) = i>it.k ? nothing : (((i-1)*it.d+min(i-1,it.r)+1):(i*it.d+min(i,it.r)), i+1)
23+
function Base.iterate(it::RangeIterator, i::Int=1)
24+
i>it.k && return nothing
25+
e = i*it.d + min(i,it.r)
26+
s = e - it.d + (i>it.r)
27+
(s:e,i+1)
28+
end
29+
30+
831
# * ThreadedSparseMatrixCSC
932

1033
"""
@@ -39,11 +62,13 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
3962
if β != 1
4063
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
4164
end
42-
Threads.@threads for k = 1:size(C, 2)
43-
@inbounds for col = 1:size(A, 2)
44-
αxj = B[col,k] * α
45-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
46-
C[rv[j], k] += nzv[j]*αxj
65+
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
66+
Threads.@spawn for k in r
67+
@inbounds for col = 1:size(A, 2)
68+
αxj = B[col,k] * α
69+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
70+
C[rv[j], k] += nzv[j]*αxj
71+
end
4772
end
4873
end
4974
end
@@ -61,13 +86,15 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}
6186
if β != 1
6287
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
6388
end
64-
Threads.@threads for k = 1:size(C, 2)
65-
@inbounds for col = 1:size(A, 2)
66-
tmp = zero(eltype(C))
67-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
68-
tmp += adjoint(nzv[j])*B[rv[j],k]
89+
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
90+
Threads.@spawn for k in r
91+
@inbounds for col = 1:size(A, 2)
92+
tmp = zero(eltype(C))
93+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
94+
tmp += adjoint(nzv[j])*B[rv[j],k]
95+
end
96+
C[col,k] += tmp * α
6997
end
70-
C[col,k] += tmp * α
7198
end
7299
end
73100
C
@@ -83,13 +110,15 @@ function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrix
83110
if β != 1
84111
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
85112
end
86-
Threads.@threads for k = 1:size(C, 2)
87-
@inbounds for col = 1:size(A, 2)
88-
tmp = zero(eltype(C))
89-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
90-
tmp += transpose(nzv[j])*B[rv[j],k]
113+
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
114+
Threads.@spawn for k in r
115+
@inbounds for col = 1:size(A, 2)
116+
tmp = zero(eltype(C))
117+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
118+
tmp += transpose(nzv[j])*B[rv[j],k]
119+
end
120+
C[col,k] += tmp * α
91121
end
92-
C[col,k] += tmp * α
93122
end
94123
end
95124
C
@@ -105,21 +134,17 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Thr
105134
if β != 1
106135
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
107136
end
108-
# Threads.@threads for col = 1:size(A, 2)
109-
# @inbounds for multivec_row=1:mX, k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
110-
# C[multivec_row, col] += α * X[multivec_row, rv[k]] * nzv[k] # perhaps suboptimal position of α?
111-
# end
112-
# end
113-
Threads.@threads for col = 1:size(A, 2)
114-
@inbounds for k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
115-
j = rv[k]
116-
αv = nzv[k]*α
117-
for multivec_row=1:mX
118-
C[multivec_row, col] += X[multivec_row, j] * αv
137+
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
138+
Threads.@spawn for col in r
139+
@inbounds for k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
140+
j = rv[k]
141+
αv = nzv[k]*α
142+
for multivec_row=1:mX
143+
C[multivec_row, col] += X[multivec_row, j] * αv
144+
end
119145
end
120146
end
121147
end
122-
123148
C
124149
end
125150

0 commit comments

Comments
 (0)