Skip to content

Commit d444eae

Browse files
committed
Even more
1 parent 0bffb61 commit d444eae

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/host/linalg.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,23 +275,22 @@ end
275275

276276
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
277277
D::Diagonal{<:Any, <:AbstractGPUArray},
278-
A::AbstractGPUVecOrMat)
278+
A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T}
279279
dd = D.diag
280280
d = length(dd)
281281
m, n = size(A, 1), size(A, 2)
282282
m′, n′ = size(B, 1), size(B, 2)
283283
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
284284
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
285285
@. B = dd * A
286-
287286
B
288287
end
289288

290289
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
291290
D::Diagonal{<:Any, <:AbstractGPUArray},
292-
A::AbstractGPUVecOrMat,
291+
A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}},
293292
α::Number,
294-
β::Number)
293+
β::Number) where {T}
295294
dd = D.diag
296295
d = length(dd)
297296
m, n = size(A, 1), size(A, 2)

test/testsuite/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@
234234
mul!(X, D, B)
235235
mul!(Y, Diagonal(collect(d)), collect(B))
236236
@test collect(X) Y
237+
mul!(X, D, adjoint(B))
238+
mul!(Y, Diagonal(collect(d)), collect(adjoint(B)))
239+
@test collect(X) Y
237240
mul!(X, D, B, α, β)
238241
mul!(Y, Diagonal(collect(d)), collect(B), α, β)
239242
@test collect(X) Y

0 commit comments

Comments
 (0)