diff --git a/src/linalg.jl b/src/linalg.jl index 131a21bc..9cf91d29 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -205,6 +205,27 @@ const SparseOrTri{Tv,Ti} = Union{SparseMatrixCSCUnion{Tv,Ti},SparseTriangular{Tv *(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::SparseOrTri) = spmatmul(copy(A), B) *(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B)) +(*)(Da::Diagonal, A::Union{SparseMatrixCSCUnion, AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}}, Db::Diagonal) = Da * (A * Db) +function (*)(Da::Diagonal, A::SparseMatrixCSC, Db::Diagonal) + (size(Da, 2) == size(A,1) && size(A,2) == size(Db,1)) || + throw(DimensionMismatch("incompatible sizes")) + T = promote_op(matprod, eltype(Da), promote_op(matprod, eltype(A), eltype(Db))) + dest = similar(A, T) + vals_dest = nonzeros(dest) + rows = rowvals(A) + vals = nonzeros(A) + da, db = map(parent, (Da, Db)) + for col in axes(A,2) + dbcol = db[col] + for i in nzrange(A, col) + row = rows[i] + val = vals[i] + vals_dest[i] = da[row] * val * dbcol + end + end + dest +end + # Gustavson's matrix multiplication algorithm revisited. # The result rowval vector is already sorted by construction. # The auxiliary Vector{Ti} xb is replaced by a Vector{Bool} of same length. diff --git a/test/linalg.jl b/test/linalg.jl index 45d42d9f..d3f004ca 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -912,4 +912,24 @@ end @test sparse(3I, 4, 5) == sparse(1:4, 1:4, 3, 4, 5) @test sparse(3I, 5, 4) == sparse(1:4, 1:4, 3, 5, 4) end + +@testset "diagonal-sandwiched triple multiplication" begin + S = sprand(4, 6, 0.2) + D1 = Diagonal(axes(S,1)) + D2 = Diagonal(axes(S,2) .+ 4) + A = Array(S) + C = D1 * S * D2 + @test C isa SparseMatrixCSC + @test C ≈ D1 * A * D2 + C = D2 * S' * D1 + @test C isa SparseMatrixCSC + @test C ≈ D2 * A' * D1 + C = D1 * view(S, :, :) * D2 + @test C isa SparseMatrixCSC + @test C ≈ D1 * A * D2 + + @test_throws DimensionMismatch D2 * S * D2 + @test_throws DimensionMismatch D1 * S * D1 +end + end