@@ -52,6 +52,15 @@ for f in [:rowvals, :nonzeros, :getcolptr]
52
52
@eval SparseArrays.$ (f)(A:: ThreadedSparseMatrixCSC ) = SparseArrays.$ (f)(A. A)
53
53
end
54
54
55
+
56
# Sparse * sparse products are not threaded (for now). They are delegated to the wrapped
# sparse matrices, and the result is wrapped again so the return type is preserved.
for (LeftT, lop) in ((ThreadedSparseMatrixCSC, identity), (Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, adjoint), (Transpose{<:Any,<:ThreadedSparseMatrixCSC}, transpose))
    for (RightT, rop) in ((ThreadedSparseMatrixCSC, identity), (Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, adjoint), (Transpose{<:Any,<:ThreadedSparseMatrixCSC}, transpose))
        @eval function Base.:(*)(A::$LeftT, B::$RightT)
            # Applying the op once undoes the lazy wrapper (a no-op for `identity`),
            # `.A` accesses the wrapped matrix, and applying the op again restores
            # the orientation of the operand.
            lhs = $lop($lop(A).A)
            rhs = $rop($rop(B).A)
            return ThreadedSparseMatrixCSC(lhs * rhs)
        end
    end
end
62
+
63
+
55
64
function mul! (C:: StridedVecOrMat , A:: ThreadedSparseMatrixCSC , B:: Union{StridedVector,AdjOrTransDenseMatrix} , α:: Number , β:: Number )
56
65
size (A, 2 ) == size (B, 1 ) || throw (DimensionMismatch ())
57
66
size (A, 1 ) == size (C, 1 ) || throw (DimensionMismatch ())
@@ -63,9 +72,9 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
63
72
end
64
73
@sync for r in RangeIterator (size (C,2 ), Threads. nthreads ())
65
74
Threads. @spawn for k in r
66
- @inbounds for col = 1 : size (A, 2 )
75
+ @inbounds for col in 1 : size (A, 2 )
67
76
αxj = B[col,k] * α
68
- for j = getcolptr (A)[col] : ( getcolptr (A)[col + 1 ] - 1 )
77
+ for j in nzrange (A, col )
69
78
C[rv[j], k] += nzv[j]* αxj
70
79
end
71
80
end
@@ -74,98 +83,53 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
74
83
C
75
84
end
76
85
77
- function mul! (C:: StridedVecOrMat , adjA:: Adjoint{<:Any,<:ThreadedSparseMatrixCSC} , B:: AdjOrTransDenseMatrix , α:: Number , β:: Number )
78
- A = adjA. parent
79
- size (A, 2 ) == size (C, 1 ) || throw (DimensionMismatch ())
80
- size (A, 1 ) == size (B, 1 ) || throw (DimensionMismatch ())
81
- size (B, 2 ) == size (C, 2 ) || throw (DimensionMismatch ())
82
- colptrA = getcolptr (A)
83
- nzv = nonzeros (A)
84
- rv = rowvals (A)
85
- if β != 1
86
- β != 0 ? rmul! (C, β) : fill! (C, zero (eltype (C)))
87
- end
88
- @sync for r in RangeIterator (size (C,2 ), Threads. nthreads ())
89
- Threads. @spawn for k in r
90
- @inbounds for col = 1 : size (A, 2 )
91
- tmp = zero (eltype (C))
92
- for j = getcolptr (A)[col]: (getcolptr (A)[col + 1 ] - 1 )
93
- tmp += adjoint (nzv[j])* B[rv[j],k]
94
- end
95
- C[col,k] += tmp * α
96
- end
86
# Threaded kernels for C = α * op(A) * B + β * C with op in (adjoint, transpose).
# Both methods are generated once per lazy wrapper type `T` with its element-wise op `t`.
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
    # Dense right-hand side: the columns of C are divided among tasks.
    # NOTE(review): RangeIterator is defined elsewhere; presumably it yields one
    # sub-range of 1:n per chunk -- confirm against its definition.
    @eval function mul!(
        C::StridedVecOrMat,
        xA::$T{<:Any,<:ThreadedSparseMatrixCSC},
        B::AdjOrTransDenseMatrix,
        α::Number,
        β::Number,
    )
        A = xA.parent
        ncols = size(A, 2)
        ncols == size(C, 1) || throw(DimensionMismatch())
        size(A, 1) == size(B, 1) || throw(DimensionMismatch())
        size(B, 2) == size(C, 2) || throw(DimensionMismatch())
        vals = nonzeros(A)
        rows = rowvals(A)
        # Apply β to C up front; β == 1 leaves C untouched.
        if β == 0
            fill!(C, zero(eltype(C)))
        elseif β != 1
            rmul!(C, β)
        end
        @sync for chunk in RangeIterator(size(C, 2), Threads.nthreads())
            Threads.@spawn for k in chunk
                @inbounds for col in 1:ncols
                    # Column `col` of A is row `col` of op(A): accumulate its dot
                    # product with column k of B.
                    acc = zero(eltype(C))
                    for p in nzrange(A, col)
                        acc += $t(vals[p]) * B[rows[p], k]
                    end
                    C[col, k] += acc * α
                end
            end
        end
        return C
    end

    # Vector right-hand side: the entries of C (one per column of A) are divided
    # among tasks instead.
    @eval function mul!(
        C::StridedVecOrMat,
        xA::$T{<:Any,<:ThreadedSparseMatrixCSC},
        B::StridedVector,
        α::Number,
        β::Number,
    )
        A = xA.parent
        ncols = size(A, 2)
        ncols == size(C, 1) || throw(DimensionMismatch())
        size(A, 1) == size(B, 1) || throw(DimensionMismatch())
        size(B, 2) == size(C, 2) || throw(DimensionMismatch())
        @assert size(B, 2) == 1
        vals = nonzeros(A)
        rows = rowvals(A)
        # Apply β to C up front; β == 1 leaves C untouched.
        if β == 0
            fill!(C, zero(eltype(C)))
        elseif β != 1
            rmul!(C, β)
        end
        @sync for chunk in RangeIterator(ncols, Threads.nthreads())
            Threads.@spawn @inbounds for col in chunk
                acc = zero(eltype(C))
                for p in nzrange(A, col)
                    acc += $t(vals[p]) * B[rows[p]]
                end
                # C has a single column here (checked above), so linear indexing
                # addresses entry `col`.
                C[col] += acc * α
            end
        end
        return C
    end
end
170
134
171
135
function mul! (C:: StridedVecOrMat , X:: AdjOrTransDenseMatrix , A:: ThreadedSparseMatrixCSC , α:: Number , β:: Number )
@@ -178,18 +142,47 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::ThreadedSparseMat
178
142
if β != 1
179
143
β != 0 ? rmul! (C, β) : fill! (C, zero (eltype (C)))
180
144
end
145
+ # TODO: split into the cases X isa DenseMatrixUnion and X isa Adjoint/Transpose, so that @simd can be used in the first case (see the original code in SparseArrays)
181
146
@sync for r in RangeIterator (size (A,2 ), Threads. nthreads ())
182
147
Threads. @spawn for col in r
183
- @inbounds for k= getcolptr (A)[ col] : ( getcolptr (A)[col + 1 ] - 1 )
184
- j = rv [k]
185
- αv = nzv [k]* α
186
- for multivec_row= 1 : mX
187
- C[multivec_row, col] += X[multivec_row, j ] * αv
148
+ @inbounds for k in nzrange (A, col)
149
+ Aiα = nzv [k] * α
150
+ rvk = rv [k]
151
+ for multivec_row in 1 : mX
152
+ C[multivec_row, col] += X[multivec_row, rvk ] * Aiα
188
153
end
189
154
end
190
155
end
191
156
end
192
157
C
193
158
end
194
159
160
# Threaded kernel for C = α * X * op(A) + β * C with op in (adjoint, transpose),
# generated once per lazy wrapper type `T` with its element-wise op `t`.
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
    @eval function mul!(
        C::StridedVecOrMat,
        X::AdjOrTransDenseMatrix,
        xA::$T{<:Any,<:ThreadedSparseMatrixCSC},
        α::Number,
        β::Number,
    )
        A = xA.parent
        mX, nX = size(X)
        ncols = size(A, 2)
        nX == ncols || throw(DimensionMismatch())
        mX == size(C, 1) || throw(DimensionMismatch())
        size(A, 1) == size(C, 2) || throw(DimensionMismatch())
        rows = rowvals(A)
        vals = nonzeros(A)
        # Apply β to C up front; β == 1 leaves C untouched.
        if β == 0
            fill!(C, zero(eltype(C)))
        elseif β != 1
            rmul!(C, β)
        end

        # Mirror image of the threaded sparse * dense kernel: the rows of C are
        # divided among tasks, so each task only writes to the rows in its chunk.
        # NOTE(review): RangeIterator is defined elsewhere; presumably it yields one
        # sub-range of 1:n per chunk -- confirm against its definition.
        @sync for chunk in RangeIterator(size(C, 1), Threads.nthreads())
            Threads.@spawn for k in chunk
                @inbounds for col in 1:ncols
                    scaled = X[k, col] * α
                    for p in nzrange(A, col)
                        # Stored entry A[rows[p], col] is entry (col, rows[p]) of op(A).
                        C[k, rows[p]] += $t(vals[p]) * scaled
                    end
                end
            end
        end
        return C
    end
end
187
+
195
188
end # module
0 commit comments