From 59b33b1f06369e5d1f9a44a61cfbfb99571e5a75 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 8 Apr 2025 13:29:27 +0530 Subject: [PATCH] Fewer `MulAddMul` branches in `Diagonal`-triangular mul --- src/diagonal.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diagonal.jl b/src/diagonal.jl index 287b5628..b56ebea6 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -470,19 +470,19 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al for j in axes(B, 2) # store the diagonal separately for unit triangular matrices if isunit - @inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j)) + @inbounds _modify_nonzeroalpha!(D.diag[j] * B[j,j], out, (j,j), alpha, beta) end # The indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) @inbounds @simd for i in rowrange - @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) + _modify_nonzeroalpha!(D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j), alpha, beta) end # Fill the indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros if !_has_matching_zeros(out, B) rowrange = _rowrange_tri_zeros(B, j) @inbounds @simd for i in rowrange - @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) + _modify_nonzeroalpha!(D.diag[i] * B[i,j], out, (i,j), alpha, beta) end end end @@ -511,7 +511,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, al # we may directly read and write from the parents out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) for j in axes(A, 2) - dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j]) + dja = @inbounds _djalpha_nonzero(D.diag[j], alpha) # store the diagonal separately for unit triangular matrices if isunit # since alpha is multiplied to the diagonal element of D, @@ -547,7 +547,7 @@ end d2 = D2.diag outd = out.diag @inbounds @simd for i in eachindex(d1, d2, outd) - @stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i) + _modify_nonzeroalpha!(d1[i] * d2[i], outd, i, alpha, beta) end return out end