diff --git a/src/blas/highlevel.jl b/src/blas/highlevel.jl index 3391eec96..72dd25b14 100644 --- a/src/blas/highlevel.jl +++ b/src/blas/highlevel.jl @@ -340,6 +340,12 @@ function LinearAlgebra.rmul!(A::ROCMatrix{T}, B::Diagonal{T,<:ROCVector{T}}) whe return dgmm!('R', A, B.diag, A) end +function LinearAlgebra.mul!(C::Diagonal{T, <:ROCVector}, A::Union{<:ROCMatrix{T}, Adjoint{T, <:ROCMatrix}, Transpose{T, <:ROCMatrix}}, B::Union{<:ROCMatrix{T}, Adjoint{T, <:ROCMatrix}, Transpose{T, <:ROCMatrix}}) where {T<:ROCBLASFloat} + Cfull = A*B + C.diag .= diag(Cfull) + return C +end + # eltypes do not match function LinearAlgebra.lmul!(A::Diagonal{T,<:ROCVector{T}}, B::ROCMatrix) where {T<:ROCBLASFloat} @. B = A.diag * B diff --git a/test/rocarray/blas.jl b/test/rocarray/blas.jl index e1b2b5073..0ca9f8bab 100644 --- a/test/rocarray/blas.jl +++ b/test/rocarray/blas.jl @@ -484,6 +484,9 @@ end @test testf( (c, a, b) -> mul!(c, a, Diagonal(b)), zeros(T, m, m), rand(T, m, m), rand(T, m, m)) + @test testf( + (c, a, b) -> mul!(Diagonal(c), a, b), + zeros(T, m), diagm(rand(T, m)), diagm(rand(T, m))) end end