Skip to content

Commit 0bffb61

Browse files
committed
More diag mul methods
1 parent a3af656 commit 0bffb61

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/host/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
259259
end
260260

261261
function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
262-
A::AbstractGPUArray,
263-
B::AbstractGPUArray)
262+
A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}},
263+
B::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T}
264264
dc = C.diag
265265
d = length(dc)
266266
m, n = size(A, 1), size(A, 2)

test/testsuite/linalg.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@
255255
C = Diagonal(d)
256256
mul!(C, a, b)
257257
@test collect(C) Diagonal(collect(a) * collect(b))
258+
a = transpose(AT(diagm(rand(elty, n))))
259+
b = adjoint(AT(diagm(rand(elty, n))))
260+
C = Diagonal(d)
261+
mul!(C, a, b)
262+
@test collect(C) Diagonal(collect(a) * collect(b))
258263
end
259264
end
260265

0 commit comments

Comments
 (0)