diff --git a/src/diagonal.jl b/src/diagonal.jl index 27ff3b1f..ac76b33d 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -382,6 +382,14 @@ end end return out end +@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Bool, beta::Number) + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + @stable_muladdmul _modify!(MulAddMul(true,beta), D.diag[i] * B[i,j], out, (i,j)) + end + end + return out +end _has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true _has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true _has_matching_zeros(out, A) = false @@ -429,6 +437,15 @@ end end return out end +@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Bool, beta::Number) + @inbounds for j in axes(A, 2) + dja = @stable_muladdmul MulAddMul(true,false)(D.diag[j]) + @simd for i in axes(A, 1) + @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) + end + end + return out +end function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) __muldiag_nonzeroalpha_right!(out, A, D, alpha, beta)