Skip to content

Commit 33491e0

Browse files
added diagonal-sparse multiplication (#564)
Co-authored-by: Daniel Karrasch <[email protected]>
1 parent 8f02b7f commit 33491e0

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/linalg.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,23 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B
188188
C
189189
end
190190

191+
function *(A::Diagonal, b::AbstractSparseVector)
192+
if size(A, 2) != length(b)
193+
throw(
194+
DimensionMismatch(lazy"The dimension of the matrix A $(size(A)) and of the vector b $(length(b))")
195+
)
196+
end
197+
T = promote_eltype(A, b)
198+
res = similar(b, T)
199+
nzind_b = nonzeroinds(b)
200+
nzval_b = nonzeros(b)
201+
nzval_res = nonzeros(res)
202+
for idx in eachindex(nzind_b)
203+
nzval_res[idx] = A.diag[nzind_b[idx]] * nzval_b[idx]
204+
end
205+
return res
206+
end
207+
191208
# Sparse matrix multiplication as described in [Gustavson, 1978]:
192209
# http://dl.acm.org/citation.cfm?id=355796
193210

test/linalg.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,21 @@ end
673673
end
674674
end
675675

676+
@testset "diagonal - sparse vector mutliplication" begin
677+
for _ in 1:10
678+
b = spzeros(10)
679+
b[1:3] .= 1:3
680+
A = Diagonal(randn(10))
681+
@test norm(A * b - A * Vector(b)) <= 10eps()
682+
@test norm(A * b - Array(A) * b) <= 10eps()
683+
Ac = Diagonal(randn(Complex{Float64}, 10))
684+
@test norm(Ac * b - Ac * Vector(b)) <= 10eps()
685+
@test norm(Ac * b - Array(Ac) * b) <= 10eps()
686+
@test_throws DimensionMismatch A * [b; 1]
687+
@test_throws DimensionMismatch A * b[1:end-1]
688+
end
689+
end
690+
676691
@testset "sparse matrix * BitArray" begin
677692
A = sprand(5,5,0.3)
678693
MA = Array(A)

0 commit comments

Comments
 (0)