Skip to content

Commit 7f4b8b3

Browse files
Threaded Adjoint times vector
1 parent 8a94d20 commit 7f4b8b3

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

src/ThreadedSparseArrays.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
7070
C
7171
end
7272

73-
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
73+
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix, α::Number, β::Number)
7474
A = adjA.parent
7575
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
7676
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
@@ -94,8 +94,31 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}
9494
end
9595
C
9696
end
97+
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::StridedVector, α::Number, β::Number)
98+
A = adjA.parent
99+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
100+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
101+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
102+
@assert size(B,2)==1
103+
colptrA = getcolptr(A)
104+
nzv = nonzeros(A)
105+
rv = rowvals(A)
106+
if β != 1
107+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
108+
end
109+
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
110+
Threads.@spawn @inbounds for col = r
111+
tmp = zero(eltype(C))
112+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
113+
tmp += adjoint(nzv[j])*B[rv[j]]
114+
end
115+
C[col] += tmp * α
116+
end
117+
end
118+
C
119+
end
97120

98-
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
121+
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix, α::Number, β::Number)
99122
A = transA.parent
100123
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
101124
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
@@ -118,6 +141,28 @@ function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrix
118141
end
119142
C
120143
end
144+
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::StridedVector, α::Number, β::Number)
145+
A = transA.parent
146+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
147+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
148+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
149+
@assert size(B,2)==1
150+
nzv = nonzeros(A)
151+
rv = rowvals(A)
152+
if β != 1
153+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
154+
end
155+
@sync for r in RangeIterator(size(A,2), Threads.nthreads())
156+
Threads.@spawn @inbounds for col = r
157+
tmp = zero(eltype(C))
158+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
159+
tmp += transpose(nzv[j])*B[rv[j]]
160+
end
161+
C[col] += tmp * α
162+
end
163+
end
164+
C
165+
end
121166

122167
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::ThreadedSparseMatrixCSC, α::Number, β::Number)
123168
mX, nX = size(X)

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,14 @@ using Test
3939
@test norm(ref-out) == 0
4040
end
4141

42+
x = rand(0:1,N)
43+
@testset "$(Mat)_L_$(op)_vec" for op in [adjoint,transpose], Mat in [ThreadedSparseMatrixCSC]
44+
Ct = Mat(C)
45+
46+
out = zeros(T, n)
47+
LinearAlgebra.mul!(out, op(Ct), x)
48+
ref = op(C)*x
49+
@test norm(ref-out) == 0
50+
end
51+
4252
end

0 commit comments

Comments
 (0)