Skip to content

Commit 5940765

Browse files
committed
lmul/rmul support for Diagonals
1 parent 9ca7a1d commit 5940765

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/host/linalg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,7 @@ function generic_rmul!(X::AbstractArray, s::Number)
683683
end
684684

685685
LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b)
686+
LinearAlgebra.rmul!(A::Diagonal{T, <:AbstractGPUArray}, b::Number) where {T} = A .* b
686687

687688
function generic_lmul!(s::Number, X::AbstractArray)
688689
@kernel function lmul_kernel!(X, s)
@@ -694,6 +695,7 @@ function generic_lmul!(s::Number, X::AbstractArray)
694695
end
695696

696697
LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
698+
LinearAlgebra.lmul!(a::Number, B::Diagonal{T, <:AbstractGPUArray}) where {T} = a .* B
697699

698700

699701
## permutedims

test/testsuite/linalg.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,21 @@ end
437437
A_empty = randn(Float32, 0, 0)
438438
@test compare(f, AT, A_empty, d)
439439
end
440+
441+
@testset "rmul!/lmul! with diagonal and number" begin
442+
n = 32
443+
h_d = rand(Float32, n)
444+
h_D = Diagonal(h_d)
445+
d = AT(h_d)
446+
D = Diagonal(d)
447+
a = rand(Float32)
448+
rmul!(D, a)
449+
rmul!(h_D, a)
450+
@test collect(D) h_D
451+
lmul!(a, D)
452+
lmul!(a, h_D)
453+
@test collect(D) h_D
454+
end
440455
end
441456

442457
@testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin

0 commit comments

Comments
 (0)