@@ -30,6 +30,26 @@ for f in [:rowvals, :nonzeros, :getcolptr]
30
30
@eval SparseArrays.$ (f)(A:: ThreadedSparseMatrixCSC ) = SparseArrays.$ (f)(A. A)
31
31
end
32
32
33
"""
    mul!(C, A::ThreadedSparseMatrixCSC, B, α, β)

Overwrite `C` in place with `A*B*α + C*β` and return `C`.

`C` is first scaled by `β`; when `β == 0` it is filled with zeros instead of being
multiplied, and when `β == 1` it is left untouched. The product `A*B*α` is then
accumulated into `C` column by column, with the loop over the columns of `C` run
under `Threads.@threads`.

# Throws
- `DimensionMismatch` if `size(A, 2) != size(B, 1)`, `size(A, 1) != size(C, 1)`, or
  `size(B, 2) != size(C, 2)`.
"""
function mul!(
    C::StridedVecOrMat,
    A::ThreadedSparseMatrixCSC,
    B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix},
    α::Number,
    β::Number,
)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch(
        "second dimension of A, $(size(A, 2)), does not match first dimension of B, $(size(B, 1))",
    ))
    size(A, 1) == size(C, 1) || throw(DimensionMismatch(
        "first dimension of A, $(size(A, 1)), does not match first dimension of C, $(size(C, 1))",
    ))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(
        "second dimension of B, $(size(B, 2)), does not match second dimension of C, $(size(C, 2))",
    ))

    # Fetch the storage vectors once so the hot loops below only do plain indexing.
    nzv = nonzeros(A)
    rv = rowvals(A)
    colptr = getcolptr(A)

    # Apply β up front; β == 0 overwrites C with zeros rather than multiplying it.
    if β != 1
        β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
    end

    # Each iteration of the threaded loop writes only to column k of C.
    Threads.@threads for k = 1:size(C, 2)
        @inbounds for col = 1:size(A, 2)
            αxj = B[col, k] * α
            for j = colptr[col]:(colptr[col + 1] - 1)
                C[rv[j], k] += nzv[j] * αxj
            end
        end
    end
    return C
end
52
+
33
53
function mul! (C:: StridedVecOrMat , adjA:: Adjoint{<:Any,<:ThreadedSparseMatrixCSC} , B:: Union{StridedVector,AdjOrTransStridedOrTriangularMatrix} , α:: Number , β:: Number )
34
54
A = adjA. parent
35
55
size (A, 2 ) == size (C, 1 ) || throw (DimensionMismatch ())
@@ -41,15 +61,13 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}
41
61
if β != 1
42
62
β != 0 ? rmul! (C, β) : fill! (C, zero (eltype (C)))
43
63
end
44
- for k = 1 : size (C, 2 )
45
- Threads. @threads for col = 1 : size (A, 2 )
46
- @inbounds begin
47
- tmp = zero (eltype (C))
48
- for j = colptrA[col]: (colptrA[col+ 1 ] - 1 )
49
- tmp += adjoint (nzv[j])* B[rv[j],k]
50
- end
51
- C[col,k] += α* tmp
64
+ Threads. @threads for k = 1 : size (C, 2 )
65
+ @inbounds for col = 1 : size (A, 2 )
66
+ tmp = zero (eltype (C))
67
+ for j = getcolptr (A)[col]: (getcolptr (A)[col + 1 ] - 1 )
68
+ tmp += adjoint (nzv[j])* B[rv[j],k]
52
69
end
70
+ C[col,k] += tmp * α
53
71
end
54
72
end
55
73
C
0 commit comments