diff --git a/src/diagonal.jl b/src/diagonal.jl index 878eaad7..7311f6e1 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -438,10 +438,13 @@ function _lmul!(D::Diagonal, A::UpperOrLowerTriangular) return TriWrapper(P) end +@propagate_inbounds _modify_nonzeroalpha!(x, out, ind, alpha, beta) = @stable_muladdmul _modify!(MulAddMul(alpha,beta), x, out, ind) +@propagate_inbounds _modify_nonzeroalpha!(x, out, ind, ::Bool, beta) = @stable_muladdmul _modify!(MulAddMul(true,beta), x, out, ind) + @inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number) @inbounds for j in axes(B, 2) @simd for i in axes(B, 1) - @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) + _modify_nonzeroalpha!(D.diag[i] * B[i,j], out, (i,j), alpha, beta) end end return out @@ -484,9 +487,12 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al return out end +@inline _djalpha_nonzero(dj, alpha) = @stable_muladdmul MulAddMul(alpha,false)(dj) +@inline _djalpha_nonzero(dj, ::Bool) = dj + @inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number) @inbounds for j in axes(A, 2) - dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j]) + dja = _djalpha_nonzero(D.diag[j], alpha) @simd for i in axes(A, 1) @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end