Skip to content

Commit 8094ded

Browse files
authored
Add/unify broadcast in mul! with Diagonal (#574)
1 parent 0f19033 commit 8094ded

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/host/linalg.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
283283
m′, n′ = size(B, 1), size(B, 2)
284284
n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d"))
285285
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
286-
B .= A .* transpose(dd)
286+
ddT = transpose(dd)
287+
@. B = A * ddT
287288

288289
B
289290
end
@@ -299,7 +300,8 @@ function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
299300
m′, n′ = size(B, 1), size(B, 2)
300301
n == d || throw(DimensionMismatch("left hand side has $n columns but D is $d by $d"))
301302
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
302-
B .= α * A .* transpose(dd) + β * B
303+
ddT = transpose(dd)
304+
@. B = α * A * ddT + β * B
303305

304306
B
305307
end

0 commit comments

Comments
 (0)