@@ -275,23 +275,22 @@ end
275275
276276function LinearAlgebra. mul!(B:: AbstractGPUVecOrMat ,
277277 D:: Diagonal{<:Any, <:AbstractGPUArray} ,
278- A:: AbstractGPUVecOrMat )
278+ A:: Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}} ) where {T}
279279 dd = D. diag
280280 d = length(dd)
281281 m, n = size(A, 1 ), size(A, 2 )
282282 m′, n′ = size(B, 1 ), size(B, 2 )
283283 m == d || throw(DimensionMismatch(" right hand side has $m rows but D is $d by $d " ))
284284 (m, n) == (m′, n′) || throw(DimensionMismatch(" expect output to be $m by $n , but got $m′ by $n′ " ))
285285 @. B = dd * A
286-
287286 B
288287end
289288
290289function LinearAlgebra. mul!(B:: AbstractGPUVecOrMat ,
291290 D:: Diagonal{<:Any, <:AbstractGPUArray} ,
292- A:: AbstractGPUVecOrMat ,
291+ A:: Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}} ,
293292 α:: Number ,
294- β:: Number )
293+ β:: Number ) where {T}
295294 dd = D. diag
296295 d = length(dd)
297296 m, n = size(A, 1 ), size(A, 2 )
0 commit comments