diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 2282528e7..9fad7bbb8 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -258,6 +258,21 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, return C end +function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, + A::AbstractGPUArray, + B::AbstractGPUArray) + dc = C.diag + d = length(dc) + m, n = size(A, 1), size(A, 2) + m′, n′ = size(B, 1), size(B, 2) + m == d || throw(DimensionMismatch("left hand side has $m rows but output is $d by $d")) + n′ == d || throw(DimensionMismatch("right hand side has $n′ cols but output is $d by $d")) + C_ = A * B + isdiag(C_) || throw(ErrorException("output matrix must be diagonal")) + dc .= diag(C_) + return C +end + function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, D::Diagonal{<:Any, <:AbstractGPUArray}, A::AbstractGPUVecOrMat) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 5770c85f0..bdcc6cd50 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -250,6 +250,11 @@ A = Diagonal(a) mul!(C, A, B) @test collect(C.diag) ≈ collect(A.diag) .* collect(B.diag) + a = AT(diagm(rand(elty, n))) + b = AT(diagm(rand(elty, n))) + C = Diagonal(d) + mul!(C, a, b) + @test collect(C) ≈ Diagonal(collect(a) * collect(b)) end end