Skip to content

Commit b1bcca1

Browse files
authored
Branch on Bool alpha in diagonal matmul (#1256)
The idea is that if `alpha` is known to be non-zero and is a `Bool`, it must be `true`. We may therefore hardcode the value to reduce the number of branches in `@stable_muladdmul`.

TTFX:
```julia
julia> using LinearAlgebra

julia> D = Diagonal(1:4); A = zeros(4,4);

julia> @time A * D;
  0.079938 seconds (139.62 k allocations: 6.952 MiB, 99.92% compilation time) # master
  0.058087 seconds (126.77 k allocations: 6.290 MiB, 99.88% compilation time) # this PR
```
The TTFX in `D * A` does not change by much, but the allocations go down.
```julia
julia> @time D * A;
  0.062484 seconds (176.66 k allocations: 8.696 MiB, 99.91% compilation time) # master
  0.059009 seconds (133.34 k allocations: 6.572 MiB, 99.91% compilation time) # this PR
```
1 parent e3e9987 commit b1bcca1

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diagonal.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,10 +438,13 @@ function _lmul!(D::Diagonal, A::UpperOrLowerTriangular)
438438
return TriWrapper(P)
439439
end
440440

441+
# Generic fallback: forward `x`, `out` and `ind` to `_modify!` using a
# `MulAddMul` built from the given `alpha` and `beta`. The call is wrapped in
# `@stable_muladdmul`.
@propagate_inbounds function _modify_nonzeroalpha!(x, out, ind, alpha, beta)
    return @stable_muladdmul _modify!(MulAddMul(alpha,beta), x, out, ind)
end
442+
# `Bool` specialization: the value of the `Bool` argument is never read and
# `true` is hardcoded in its place (callers are expected to pass a non-zero
# `alpha`, and a non-zero `Bool` is necessarily `true`).
@propagate_inbounds function _modify_nonzeroalpha!(x, out, ind, ::Bool, beta)
    return @stable_muladdmul _modify!(MulAddMul(true,beta), x, out, ind)
end
443+
441444
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
442445
@inbounds for j in axes(B, 2)
443446
@simd for i in axes(B, 1)
444-
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
447+
_modify_nonzeroalpha!(D.diag[i] * B[i,j], out, (i,j), alpha, beta)
445448
end
446449
end
447450
return out
@@ -484,9 +487,12 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al
484487
return out
485488
end
486489

490+
# Generic fallback: apply `MulAddMul(alpha, false)` to `dj`, with the call
# wrapped in `@stable_muladdmul`.
@inline function _djalpha_nonzero(dj, alpha)
    return @stable_muladdmul MulAddMul(alpha,false)(dj)
end
491+
# `Bool` specialization: the `Bool` argument is ignored and `dj` is returned
# unchanged (callers are expected to pass a non-zero `alpha`, and a non-zero
# `Bool` is necessarily `true`).
@inline function _djalpha_nonzero(dj, ::Bool)
    return dj
end
492+
487493
@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number)
488494
@inbounds for j in axes(A, 2)
489-
dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j])
495+
dja = _djalpha_nonzero(D.diag[j], alpha)
490496
@simd for i in axes(A, 1)
491497
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
492498
end

0 commit comments

Comments
 (0)