Skip to content

Commit e5ce333

Browse files
committed
Branch on Bool alpha in diagonal matmul
1 parent e53b50c commit e5ce333

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diagonal.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,10 +438,13 @@ function _lmul!(D::Diagonal, A::UpperOrLowerTriangular)
438438
return TriWrapper(P)
439439
end
440440

441+
@inline _modify_nonzeroalpha!(x, out, ind, alpha, beta) = @stable_muladdmul _modify!(MulAddMul(alpha,beta), x, out, ind)
442+
@inline _modify_nonzeroalpha!(x, out, ind, ::Bool, beta) = @stable_muladdmul _modify!(MulAddMul(true,beta), x, out, ind)
443+
441444
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
442445
@inbounds for j in axes(B, 2)
443446
@simd for i in axes(B, 1)
444-
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
447+
_modify_nonzeroalpha!(D.diag[i] * B[i,j], out, (i,j), alpha, beta)
445448
end
446449
end
447450
return out
@@ -484,9 +487,12 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al
484487
return out
485488
end
486489

490+
@inline _djalpha_nonzero(dj, alpha) = @stable_muladdmul MulAddMul(alpha,false)(dj)
491+
@inline _djalpha_nonzero(dj, ::Bool) = dj
492+
487493
@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number)
488494
@inbounds for j in axes(A, 2)
489-
dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j])
495+
dja = _djalpha_nonzero(D.diag[j], alpha)
490496
@simd for i in axes(A, 1)
491497
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
492498
end

0 commit comments

Comments
 (0)