diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 8f871868b..0002fdd54 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -266,8 +266,8 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, end function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, - A::AbstractGPUArray, - B::AbstractGPUArray) + A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}, + B::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T} dc = C.diag d = length(dc) m, n = size(A, 1), size(A, 2) @@ -282,7 +282,7 @@ end function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, D::Diagonal{<:Any, <:AbstractGPUArray}, - A::AbstractGPUVecOrMat) + A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T} dd = D.diag d = length(dd) m, n = size(A, 1), size(A, 2) @@ -290,15 +290,14 @@ function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d")) (m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′")) @. B = dd * A - B end function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, D::Diagonal{<:Any, <:AbstractGPUArray}, - A::AbstractGPUVecOrMat, + A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}, α::Number, - β::Number) + β::Number) where {T} dd = D.diag d = length(dd) m, n = size(A, 1), size(A, 2) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index f9ac5d924..d77a2bb50 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -238,6 +238,9 @@ mul!(X, D, B) mul!(Y, Diagonal(collect(d)), collect(B)) @test collect(X) ≈ Y + mul!(X, D, adjoint(B)) + mul!(Y, Diagonal(collect(d)), collect(adjoint(B))) + @test collect(X) ≈ Y mul!(X, D, B, α, β) mul!(Y, Diagonal(collect(d)), collect(B), α, β) @test collect(X) ≈ Y @@ -259,6 +262,11 @@ C = Diagonal(d) mul!(C, a, b) @test collect(C) ≈ Diagonal(collect(a) * collect(b)) + a = transpose(AT(diagm(rand(elty, n)))) + b = adjoint(AT(diagm(rand(elty, n)))) + C = Diagonal(d) + mul!(C, a, b) + @test collect(C) ≈ Diagonal(collect(a) * collect(b)) end end