Skip to content

Commit 062d0a4

Browse files
committed
Support mul!(Diagonal, A, B)
1 parent d794272 commit 062d0a4

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/blas/highlevel.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,12 @@ function LinearAlgebra.rmul!(A::ROCMatrix{T}, B::Diagonal{T,<:ROCVector{T}}) whe
340340
return dgmm!('R', A, B.diag, A)
341341
end
342342

343+
function LinearAlgebra.mul!(C::Diagonal{T, <:ROCVector}, A::Union{<:ROCMatrix{T}, Adjoint{T, <:ROCMatrix}, Transpose{T, <:ROCMatrix}}, B::Union{<:ROCMatrix{T}, Adjoint{T, <:ROCMatrix}, Transpose{T, <:ROCMatrix}}) where {T<:ROCBLASloat}
344+
Cfull = A*B
345+
C.diag .= diag(Cfull)
346+
return C
347+
end
348+
343349
# eltypes do not match
344350
function LinearAlgebra.lmul!(A::Diagonal{T,<:ROCVector{T}}, B::ROCMatrix) where {T<:ROCBLASFloat}
345351
@. B = A.diag * B

test/rocarray/blas.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,9 @@ end
484484
@test testf(
485485
(c, a, b) -> mul!(c, a, Diagonal(b)),
486486
zeros(T, m, m), rand(T, m, m), rand(T, m, m))
487+
@test testf(
488+
(c, a, b) -> mul!(Diagonal(c), a, b),
489+
zeros(T, m), diagm(rand(T, m)), diagm(rand(T, m)))
487490
end
488491
end
489492

0 commit comments

Comments
 (0)