diff --git a/src/linalg.jl b/src/linalg.jl index 131a21bc..01e54f71 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -188,6 +188,23 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B C end +function *(A::Diagonal, b::AbstractSparseVector) + if size(A, 2) != length(b) + throw( + DimensionMismatch(lazy"The dimension of the matrix A $(size(A)) and of the vector b $(length(b))") + ) + end + T = promote_eltype(A, b) + res = similar(b, T) + nzind_b = nonzeroinds(b) + nzval_b = nonzeros(b) + nzval_res = nonzeros(res) + for idx in eachindex(nzind_b) + nzval_res[idx] = A.diag[nzind_b[idx]] * nzval_b[idx] + end + return res +end + # Sparse matrix multiplication as described in [Gustavson, 1978]: # http://dl.acm.org/citation.cfm?id=355796 diff --git a/test/linalg.jl b/test/linalg.jl index 45d42d9f..295eba61 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -673,6 +673,21 @@ end end end +@testset "diagonal - sparse vector mutliplication" begin + for _ in 1:10 + b = spzeros(10) + b[1:3] .= 1:3 + A = Diagonal(randn(10)) + @test norm(A * b - A * Vector(b)) <= 10eps() + @test norm(A * b - Array(A) * b) <= 10eps() + Ac = Diagonal(randn(Complex{Float64}, 10)) + @test norm(Ac * b - Ac * Vector(b)) <= 10eps() + @test norm(Ac * b - Array(Ac) * b) <= 10eps() + @test_throws DimensionMismatch A * [b; 1] + @test_throws DimensionMismatch A * b[1:end-1] + end +end + @testset "sparse matrix * BitArray" begin A = sprand(5,5,0.3) MA = Array(A)