Skip to content

Commit ad084ee

Browse files
Specialize mul! for mat * diag (#425)
1 parent 26c5191 commit ad084ee

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/host/linalg.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,37 @@ else
198198
m′, n′ = size(B, 1), size(B, 2)
199199
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
200200
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
201-
@. B = α * dd* A + β * B
201+
@. B = α * dd * A + β * B
202+
203+
B
204+
end
205+
206+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
207+
A::AbstractGPUVecOrMat,
208+
D::Diagonal{<:Any, <:AbstractGPUArray})
209+
dd = D.diag
210+
d = length(dd)
211+
m, n = size(A, 1), size(A, 2)
212+
m′, n′ = size(B, 1), size(B, 2)
213+
n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d"))
214+
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
215+
@. B' = dd * A'
216+
217+
B
218+
end
219+
220+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
221+
A::AbstractGPUVecOrMat,
222+
D::Diagonal{<:Any, <:AbstractGPUArray},
223+
α::Number,
224+
β::Number)
225+
dd = D.diag
226+
d = length(dd)
227+
m, n = size(A, 1), size(A, 2)
228+
m′, n′ = size(B, 1), size(B, 2)
229+
n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d"))
230+
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
231+
@. B' = α * dd * A' + β * B'
202232

203233
B
204234
end

test/testsuite/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@
168168
mul!(X, D, B, α, β)
169169
mul!(Y, Diagonal(collect(d)), collect(B), α, β)
170170
@test collect(X) Y
171+
mul!(X, B, D)
172+
mul!(Y, collect(B), Diagonal(collect(d)))
173+
@test collect(X) Y
174+
mul!(X, B, D, α, β)
175+
mul!(Y, collect(B), Diagonal(collect(d)), α, β)
176+
@test collect(X) Y
171177
end
172178

173179
@testset "ldiv! + Diagonal" begin

0 commit comments

Comments
 (0)