@@ -330,3 +330,41 @@ if VERSION ≥ v"1.12-"
330330 LinearAlgebra. copytrito! (B:: Matrix{T} , A:: ROCMatrix{T} , uplo:: AbstractChar ) where {T <: ROCBLASFloat } =
331331 invoke (LinearAlgebra. copytrito!, Tuple{AbstractMatrix, AbstractMatrix, AbstractChar}, B, A, uplo)
332332end
333+
334+ function LinearAlgebra. lmul! (A:: Diagonal{T,<:ROCVector{T}} , B:: ROCMatrix{T} ) where {T<: ROCBLASFloat }
335+ return dgmm! (' L' , B, A. diag, B)
336+ end
337+
338+ function LinearAlgebra. rmul! (A:: ROCMatrix{T} , B:: Diagonal{T,<:ROCVector{T}} ) where {T<: ROCBLASFloat }
339+ return dgmm! (' R' , A, B. diag, A)
340+ end
341+
342+ # eltypes do not match
343+ function LinearAlgebra. lmul! (A:: Diagonal{T,<:ROCVector{T}} , B:: ROCMatrix ) where {T<: ROCBLASFloat }
344+ @. B = A. diag * B
345+ return B
346+ end
347+ function LinearAlgebra. lmul! (A:: Diagonal{Td,<:ROCVector{Td}} , B:: Transpose{Tt, <:ROCMatrix{Tt}} ) where {Td<: ROCBLASFloat , Tt<: ROCBLASFloat }
348+ @. B = A. diag * B
349+ return B
350+ end
351+ function LinearAlgebra. lmul! (A:: Diagonal{Td,<:ROCVector{Td}} , B:: Adjoint{Tt, <:ROCMatrix{Tt}} ) where {Td<: ROCBLASFloat , Tt<: ROCBLASFloat }
352+ @. B = A. diag * B
353+ return B
354+ end
355+ # eltypes do not match
356+ function LinearAlgebra. rmul! (A:: ROCMatrix , B:: Diagonal{T,<:ROCVector{T}} ) where {T<: ROCBLASFloat }
357+ At = transpose (A)
358+ @. At = B. diag * At
359+ return A
360+ end
361+ function LinearAlgebra. rmul! (A:: Transpose{Tt, <:ROCMatrix{Tt}} , B:: Diagonal{Td,<:ROCVector{Td}} ) where {Td<: ROCBLASFloat , Tt<: ROCBLASFloat }
362+ At = parent (A)
363+ @. At = B. diag * At
364+ return transpose (At)
365+ end
366+ function LinearAlgebra. rmul! (A:: Adjoint{Tt, <:ROCMatrix{Tt}} , B:: Diagonal{Td,<:ROCVector{Td}} ) where {Td<: ROCBLASFloat , Tt<: ROCBLASFloat }
367+ At = parent (A)
368+ @. At = adjoint (B. diag) * At
369+ return adjoint (At)
370+ end
0 commit comments