Skip to content

Commit 8068292

Browse files
More mul! methods (#8)
* More mul! methods

Code:
* Add missing mul! methods for dense times adj(sparse)
* ThreadedSparseMatrixCSC times ThreadedSparseMatrixCSC now returns ThreadedSparseMatrixCSC (even though the multiplication isn't threaded)
* Some updates to match upstream code changes

Unit tests:
* Test which mul! cases are threaded
* Use StableRNGs
* Test 5-argument mul!
* Also use Complex{Int64} (to get tests with exact arithmetic)
* Made ComplexF64 test more thorough
1 parent 7da5862 commit 8068292

File tree

3 files changed

+191
-161
lines changed

3 files changed

+191
-161
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111
julia = "1.4"
1212

1313
[extras]
14+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

1618
[targets]
17-
test = ["Test"]
19+
test = ["Test", "Random", "StableRNGs"]

src/ThreadedSparseArrays.jl

Lines changed: 82 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ for f in [:rowvals, :nonzeros, :getcolptr]
5252
@eval SparseArrays.$(f)(A::ThreadedSparseMatrixCSC) = SparseArrays.$(f)(A.A)
5353
end
5454

55+
56+
# sparse * sparse multiplications are not (currently) threaded, but we want to keep the return type
57+
for (T1,t1) in ((ThreadedSparseMatrixCSC,identity), (Adjoint{<:Any,<:ThreadedSparseMatrixCSC},adjoint), (Transpose{<:Any,<:ThreadedSparseMatrixCSC},transpose))
58+
for (T2,t2) in ((ThreadedSparseMatrixCSC,identity), (Adjoint{<:Any,<:ThreadedSparseMatrixCSC},adjoint), (Transpose{<:Any,<:ThreadedSparseMatrixCSC},transpose))
59+
@eval Base.:(*)(A::$T1, B::$T2) = ThreadedSparseMatrixCSC($t1($t1(A).A)*$t2($t2(B).A))
60+
end
61+
end
62+
63+
5564
function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVector,AdjOrTransDenseMatrix}, α::Number, β::Number)
5665
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
5766
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
@@ -63,9 +72,9 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
6372
end
6473
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
6574
Threads.@spawn for k in r
66-
@inbounds for col = 1:size(A, 2)
75+
@inbounds for col in 1:size(A, 2)
6776
αxj = B[col,k] * α
68-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
77+
for j in nzrange(A, col)
6978
C[rv[j], k] += nzv[j]*αxj
7079
end
7180
end
@@ -74,98 +83,53 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
7483
C
7584
end
7685

77-
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::AdjOrTransDenseMatrix, α::Number, β::Number)
78-
A = adjA.parent
79-
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
80-
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
81-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
82-
colptrA = getcolptr(A)
83-
nzv = nonzeros(A)
84-
rv = rowvals(A)
85-
if β != 1
86-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
87-
end
88-
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
89-
Threads.@spawn for k in r
90-
@inbounds for col = 1:size(A, 2)
91-
tmp = zero(eltype(C))
92-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
93-
tmp += adjoint(nzv[j])*B[rv[j],k]
94-
end
95-
C[col,k] += tmp * α
96-
end
86+
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
87+
@eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:ThreadedSparseMatrixCSC}, B::AdjOrTransDenseMatrix, α::Number, β::Number)
88+
A = xA.parent
89+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
90+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
91+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
92+
nzv = nonzeros(A)
93+
rv = rowvals(A)
94+
if β != 1
95+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
9796
end
98-
end
99-
C
100-
end
101-
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::StridedVector, α::Number, β::Number)
102-
A = adjA.parent
103-
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
104-
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
105-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
106-
@assert size(B,2)==1
107-
colptrA = getcolptr(A)
108-
nzv = nonzeros(A)
109-
rv = rowvals(A)
110-
if β != 1
111-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
112-
end
113-
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
114-
Threads.@spawn @inbounds for col = r
115-
tmp = zero(eltype(C))
116-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
117-
tmp += adjoint(nzv[j])*B[rv[j]]
97+
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
98+
Threads.@spawn for k in r
99+
@inbounds for col in 1:size(A, 2)
100+
tmp = zero(eltype(C))
101+
for j in nzrange(A, col)
102+
tmp += $t(nzv[j])*B[rv[j],k]
103+
end
104+
C[col,k] += tmp * α
105+
end
118106
end
119-
C[col] += tmp * α
120107
end
108+
C
121109
end
122-
C
123-
end
124110

125-
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::AdjOrTransDenseMatrix, α::Number, β::Number)
126-
A = transA.parent
127-
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
128-
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
129-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
130-
nzv = nonzeros(A)
131-
rv = rowvals(A)
132-
if β != 1
133-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
134-
end
135-
@sync for r in RangeIterator(size(C,2), Threads.nthreads())
136-
Threads.@spawn for k in r
137-
@inbounds for col = 1:size(A, 2)
111+
@eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:ThreadedSparseMatrixCSC}, B::StridedVector, α::Number, β::Number)
112+
A = xA.parent
113+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
114+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
115+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
116+
@assert size(B,2)==1
117+
nzv = nonzeros(A)
118+
rv = rowvals(A)
119+
if β != 1
120+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
121+
end
122+
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
123+
Threads.@spawn @inbounds for col in r
138124
tmp = zero(eltype(C))
139-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
140-
tmp += transpose(nzv[j])*B[rv[j],k]
125+
for j in nzrange(A, col)
126+
tmp += $t(nzv[j])*B[rv[j]]
141127
end
142-
C[col,k] += tmp * α
128+
C[col] += tmp * α
143129
end
144130
end
131+
C
145132
end
146-
C
147-
end
148-
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::StridedVector, α::Number, β::Number)
149-
A = transA.parent
150-
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
151-
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
152-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
153-
@assert size(B,2)==1
154-
nzv = nonzeros(A)
155-
rv = rowvals(A)
156-
if β != 1
157-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
158-
end
159-
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
160-
Threads.@spawn @inbounds for col = r
161-
tmp = zero(eltype(C))
162-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
163-
tmp += transpose(nzv[j])*B[rv[j]]
164-
end
165-
C[col] += tmp * α
166-
end
167-
end
168-
C
169133
end
170134

171135
function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::ThreadedSparseMatrixCSC, α::Number, β::Number)
@@ -178,18 +142,47 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::ThreadedSparseMat
178142
if β != 1
179143
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
180144
end
145+
# TODO: split in X isa DenseMatrixUnion and X isa Adjoint/Transpose so we can use @simd in the first case (see original code in SparseArrays)
181146
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
182147
Threads.@spawn for col in r
183-
@inbounds for k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
184-
j = rv[k]
185-
αv = nzv[k]*α
186-
for multivec_row=1:mX
187-
C[multivec_row, col] += X[multivec_row, j] * αv
148+
@inbounds for k in nzrange(A, col)
149+
Aiα = nzv[k] * α
150+
rvk = rv[k]
151+
for multivec_row in 1:mX
152+
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
188153
end
189154
end
190155
end
191156
end
192157
C
193158
end
194159

160+
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
161+
@eval function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, xA::$T{<:Any,<:ThreadedSparseMatrixCSC}, α::Number, β::Number)
162+
A = xA.parent
163+
mX, nX = size(X)
164+
nX == size(A, 2) || throw(DimensionMismatch())
165+
mX == size(C, 1) || throw(DimensionMismatch())
166+
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
167+
rv = rowvals(A)
168+
nzv = nonzeros(A)
169+
if β != 1
170+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
171+
end
172+
173+
# transpose of Threaded * Dense algorithm above
174+
@sync for r in RangeIterator(size(C,1), Threads.nthreads())
175+
Threads.@spawn for k in r
176+
@inbounds for col in 1:size(A, 2)
177+
αxj = X[k,col] * α
178+
for j in nzrange(A, col)
179+
C[k, rv[j]] += $t(nzv[j])*αxj
180+
end
181+
end
182+
end
183+
end
184+
C
185+
end
186+
end
187+
195188
end # module

0 commit comments

Comments
 (0)