Skip to content

Commit cab0dc6

Browse files
authored
Compile-time check for zero alpha in matmul (#1293)
Currently, we check for `iszero(_add.alpha)` in various matrix multiplication functions, where `_add::MulAddMul`. However, a `MulAddMul(alpha, beta)` stores `isone(alpha)` as a type parameter, and if `alpha` is one, we know at compile-time that `iszero(_add.alpha)` is `false`. Using the type parameter in dispatch would allow us to eliminate the `iszero(alpha)` branches.
1 parent 78e6156 commit cab0dc6

File tree

4 files changed

+25
-18
lines changed

4 files changed

+25
-18
lines changed

src/bidiag.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ function _bibimul!(C, A, B, _add)
613613
# `_modify!` in the following loop will not update the
614614
# off-diagonal elements for non-zero beta.
615615
_rmul_or_fill!(C, _add.beta)
616-
iszero(_add.alpha) && return C
616+
_iszero_alpha(_add) && return C
617617
if n <= 3
618618
# naive multiplication
619619
for I in CartesianIndices(C)
@@ -858,7 +858,7 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
858858
n = size(A,1)
859859
iszero(n) && return C
860860
_rmul_or_fill!(C, _add.beta) # see the same use above
861-
iszero(_add.alpha) && return C
861+
_iszero_alpha(_add) && return C
862862
Al = _diag(A, -1)
863863
Ad = _diag(A, 0)
864864
Au = _diag(A, 1)
@@ -901,7 +901,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
901901
n = size(A,1)
902902
iszero(n) && return C
903903
_rmul_or_fill!(C, _add.beta) # see the same use above
904-
iszero(_add.alpha) && return C
904+
_iszero_alpha(_add) && return C
905905
(; dv, ev) = A
906906
Bd = B.diag
907907
rowshift = A.uplo == 'U' ? -1 : 1
@@ -930,7 +930,7 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
930930
matmul_size_check(size(C), size(A), size(B))
931931
n = size(A,1)
932932
iszero(n) && return C
933-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
933+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
934934
Adv, Aev = A.dv, A.ev
935935
Cdv, Cev = C.dv, C.ev
936936
Bd = B.diag
@@ -964,7 +964,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
964964
nA = size(A,1)
965965
nB = size(B,2)
966966
(iszero(nA) || iszero(nB)) && return C
967-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
967+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
968968
if nA == 1
969969
A11 = @inbounds A[1,1]
970970
for i in axes(B, 2)
@@ -1032,7 +1032,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10321032
matmul_size_check(size(C), size(A), size(B))
10331033
n = size(A,1)
10341034
m = size(B,2)
1035-
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
1035+
(_iszero_alpha(_add) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
10361036
if m == 1
10371037
B11 = B[1,1]
10381038
return mul!(C, A, B11, _add.alpha, _add.beta)
@@ -1068,7 +1068,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10681068
matmul_size_check(size(C), size(A), size(B))
10691069
m, n = size(A)
10701070
(iszero(m) || iszero(n)) && return C
1071-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1071+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
10721072
@inbounds if B.uplo == 'U'
10731073
for j in n:-1:2, i in 1:m
10741074
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
@@ -1100,7 +1100,7 @@ function _dibimul!(C, A, B, _add)
11001100
iszero(n) && return C
11011101
# ensure that we fill off-band elements in the destination
11021102
_rmul_or_fill!(C, _add.beta)
1103-
iszero(_add.alpha) && return C
1103+
_iszero_alpha(_add) && return C
11041104
if n <= 3
11051105
# For simplicity, use a naive multiplication for small matrices
11061106
# that loops over all elements.
@@ -1144,7 +1144,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11441144
iszero(n) && return C
11451145
# ensure that we fill off-band elements in the destination
11461146
_rmul_or_fill!(C, _add.beta)
1147-
iszero(_add.alpha) && return C
1147+
_iszero_alpha(_add) && return C
11481148
Ad = A.diag
11491149
Bdv, Bev = B.dv, B.ev
11501150
rowshift = B.uplo == 'U' ? -1 : 1
@@ -1173,7 +1173,7 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11731173
matmul_size_check(size(C), size(A), size(B))
11741174
n = size(A,1)
11751175
n == 0 && return C
1176-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1176+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
11771177
Ad = A.diag
11781178
Bdv, Bev = B.dv, B.ev
11791179
Cdv, Cev = C.dv, C.ev

src/generic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)
124124
@inline (p::MulAddMul{true, false})(x, y) = x + y * p.beta
125125
@inline (p::MulAddMul{false, false})(x, y) = x * p.alpha + y * p.beta
126126

127+
_iszero_alpha(m::MulAddMul) = iszero(m.alpha)
128+
_iszero_alpha(m::MulAddMul{true}) = false
129+
127130
"""
128131
_modify!(_add::MulAddMul, x, C, idx)
129132

src/triangular.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ end
718718

719719
function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add)
720720
checksize1(A, B)
721-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
721+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
722722
for j in axes(B.data,2)
723723
for i in firstindex(B.data,1):j
724724
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
@@ -728,7 +728,7 @@ function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add)
728728
end
729729
function _triscale!(A::UpperTriangular, c::Number, B::UpperTriangular, _add)
730730
checksize1(A, B)
731-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
731+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
732732
for j in axes(B.data,2)
733733
for i in firstindex(B.data,1):j
734734
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
@@ -738,7 +738,7 @@ function _triscale!(A::UpperTriangular, c::Number, B::UpperTriangular, _add)
738738
end
739739
function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Number, _add)
740740
checksize1(A, B)
741-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
741+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
742742
for j in axes(B.data,2)
743743
@inbounds _modify!(_add, c, A, (j,j))
744744
for i in firstindex(B.data,1):(j - 1)
@@ -749,7 +749,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
749749
end
750750
function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriangular, _add)
751751
checksize1(A, B)
752-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
752+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
753753
for j in axes(B.data,2)
754754
@inbounds _modify!(_add, c, A, (j,j))
755755
for i in firstindex(B.data,1):(j - 1)
@@ -760,7 +760,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
760760
end
761761
function _triscale!(A::LowerTriangular, B::LowerTriangular, c::Number, _add)
762762
checksize1(A, B)
763-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
763+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
764764
for j in axes(B.data,2)
765765
for i in j:lastindex(B.data,1)
766766
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
@@ -770,7 +770,7 @@ function _triscale!(A::LowerTriangular, B::LowerTriangular, c::Number, _add)
770770
end
771771
function _triscale!(A::LowerTriangular, c::Number, B::LowerTriangular, _add)
772772
checksize1(A, B)
773-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
773+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
774774
for j in axes(B.data,2)
775775
for i in j:lastindex(B.data,1)
776776
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
@@ -780,7 +780,7 @@ function _triscale!(A::LowerTriangular, c::Number, B::LowerTriangular, _add)
780780
end
781781
function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Number, _add)
782782
checksize1(A, B)
783-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
783+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
784784
for j in axes(B.data,2)
785785
@inbounds _modify!(_add, c, A, (j,j))
786786
for i in (j + 1):lastindex(B.data,1)
@@ -791,7 +791,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
791791
end
792792
function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriangular, _add)
793793
checksize1(A, B)
794-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
794+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
795795
for j in axes(B.data,2)
796796
@inbounds _modify!(_add, c, A, (j,j))
797797
for i in (j + 1):lastindex(B.data,1)

test/bidiag.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,9 @@ end
996996
@test A * D mul!(S, A, D) M * D
997997
@test D * A mul!(S, D, A) D * M
998998
@test mul!(copy(S), D, A, 2, 2) D * M * 2 + S * 2
999+
@test mul!(copy(S), D, A, 0, 2) D * M * 0 + S * 2
9991000
@test mul!(copy(S), A, D, 2, 2) M * D * 2 + S * 2
1001+
@test mul!(copy(S), A, D, 0, 2) M * D * 0 + S * 2
10001002

10011003
A2 = Bidiagonal(dv, zero(ev), uplo)
10021004
M2 = Array(A2)
@@ -1074,10 +1076,12 @@ end
10741076
@test B * v M * v
10751077
@test mul!(similar(v), B, v) M * v
10761078
@test mul!(ones(size(v)), B, v, 2, 3) M * v * 2 .+ 3
1079+
@test mul!(ones(size(v)), B, v, 0, 3) M * v * 0 .+ 3
10771080

10781081
@test B * B M * M
10791082
@test mul!(similar(B, size(B)), B, B) M * M
10801083
@test mul!(ones(size(B)), B, B, 2, 4) M * M * 2 .+ 4
1084+
@test mul!(ones(size(B)), B, B, 0, 4) M * M * 0 .+ 4
10811085

10821086
for m in 0:6
10831087
AL = rand(m,n)

0 commit comments

Comments
 (0)