Skip to content

Commit 87e91dd

Browse files
authored
Specialized mul for 3 diagonal arguments (#625)
1 parent d5a60f7 commit 87e91dd

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/host/linalg.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,20 @@ function Base.:\(D::Diagonal{<:Any, <:AbstractGPUArray}, B::AbstractGPUVecOrMat)
244244
end
245245
end
246246

247+
function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray},
248+
A::Diagonal{<:Any, <:AbstractGPUArray},
249+
B::Diagonal{<:Any, <:AbstractGPUArray})
250+
dc = C.diag
251+
da = A.diag
252+
db = B.diag
253+
d = length(dc)
254+
length(da) == d || throw(DimensionMismatch("right hand side has $(length(da)) rows but output is $d by $d"))
255+
length(db) == d || throw(DimensionMismatch("left hand side has $(length(db)) rows but output is $d by $d"))
256+
@. dc = da * db
257+
258+
return C
259+
end
260+
247261
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
248262
D::Diagonal{<:Any, <:AbstractGPUArray},
249263
A::AbstractGPUVecOrMat)

test/testsuite/linalg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,13 @@
243243
mul!(X, B, D, α, β)
244244
mul!(Y, collect(B), Diagonal(collect(d)), α, β)
245245
@test collect(X) Y
246+
a = AT(rand(elty, n))
247+
b = AT(rand(elty, n))
248+
C = Diagonal(d)
249+
B = Diagonal(b)
250+
A = Diagonal(a)
251+
mul!(C, A, B)
252+
@test collect(C.diag) collect(A.diag) .* collect(B.diag)
246253
end
247254
end
248255

0 commit comments

Comments
 (0)