Skip to content

Commit df8e531

Browse files
committed
Compile-time check for zero alpha in matmul
1 parent 61e444d commit df8e531

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

src/bidiag.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ function _bibimul!(C, A, B, _add)
613613
# `_modify!` in the following loop will not update the
614614
# off-diagonal elements for non-zero beta.
615615
_rmul_or_fill!(C, _add.beta)
616-
iszero(_add.alpha) && return C
616+
_iszero_alpha(_add) && return C
617617
if n <= 3
618618
# naive multiplication
619619
for I in CartesianIndices(C)
@@ -858,7 +858,7 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
858858
n = size(A,1)
859859
iszero(n) && return C
860860
_rmul_or_fill!(C, _add.beta) # see the same use above
861-
iszero(_add.alpha) && return C
861+
_iszero_alpha(_add) && return C
862862
Al = _diag(A, -1)
863863
Ad = _diag(A, 0)
864864
Au = _diag(A, 1)
@@ -901,7 +901,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
901901
n = size(A,1)
902902
iszero(n) && return C
903903
_rmul_or_fill!(C, _add.beta) # see the same use above
904-
iszero(_add.alpha) && return C
904+
_iszero_alpha(_add) && return C
905905
(; dv, ev) = A
906906
Bd = B.diag
907907
rowshift = A.uplo == 'U' ? -1 : 1
@@ -930,7 +930,7 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
930930
matmul_size_check(size(C), size(A), size(B))
931931
n = size(A,1)
932932
iszero(n) && return C
933-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
933+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
934934
Adv, Aev = A.dv, A.ev
935935
Cdv, Cev = C.dv, C.ev
936936
Bd = B.diag
@@ -964,7 +964,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
964964
nA = size(A,1)
965965
nB = size(B,2)
966966
(iszero(nA) || iszero(nB)) && return C
967-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
967+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
968968
if nA == 1
969969
A11 = @inbounds A[1,1]
970970
for i in axes(B, 2)
@@ -1032,7 +1032,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10321032
matmul_size_check(size(C), size(A), size(B))
10331033
n = size(A,1)
10341034
m = size(B,2)
1035-
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
1035+
(_iszero_alpha(_add) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
10361036
if m == 1
10371037
B11 = B[1,1]
10381038
return mul!(C, A, B11, _add.alpha, _add.beta)
@@ -1068,7 +1068,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10681068
matmul_size_check(size(C), size(A), size(B))
10691069
m, n = size(A)
10701070
(iszero(m) || iszero(n)) && return C
1071-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1071+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
10721072
@inbounds if B.uplo == 'U'
10731073
for j in n:-1:2, i in 1:m
10741074
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
@@ -1100,7 +1100,7 @@ function _dibimul!(C, A, B, _add)
11001100
iszero(n) && return C
11011101
# ensure that we fill off-band elements in the destination
11021102
_rmul_or_fill!(C, _add.beta)
1103-
iszero(_add.alpha) && return C
1103+
_iszero_alpha(_add) && return C
11041104
if n <= 3
11051105
# For simplicity, use a naive multiplication for small matrices
11061106
# that loops over all elements.
@@ -1144,7 +1144,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11441144
iszero(n) && return C
11451145
# ensure that we fill off-band elements in the destination
11461146
_rmul_or_fill!(C, _add.beta)
1147-
iszero(_add.alpha) && return C
1147+
_iszero_alpha(_add) && return C
11481148
Ad = A.diag
11491149
Bdv, Bev = B.dv, B.ev
11501150
rowshift = B.uplo == 'U' ? -1 : 1
@@ -1173,7 +1173,7 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11731173
matmul_size_check(size(C), size(A), size(B))
11741174
n = size(A,1)
11751175
n == 0 && return C
1176-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1176+
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
11771177
Ad = A.diag
11781178
Bdv, Bev = B.dv, B.ev
11791179
Cdv, Cev = C.dv, C.ev

src/generic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false)
124124
@inline (p::MulAddMul{true, false})(x, y) = x + y * p.beta
125125
@inline (p::MulAddMul{false, false})(x, y) = x * p.alpha + y * p.beta
126126

127+
_iszero_alpha(m::MulAddMul) = iszero(M.alpha)
128+
_iszero_alpha(m::MulAddMul{true}) = false
129+
127130
"""
128131
_modify!(_add::MulAddMul, x, C, idx)
129132

src/triangular.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ end
692692

693693
function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add)
694694
checksize1(A, B)
695-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
695+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
696696
for j in axes(B.data,2)
697697
for i in firstindex(B.data,1):j
698698
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
@@ -702,7 +702,7 @@ function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add)
702702
end
703703
function _triscale!(A::UpperTriangular, c::Number, B::UpperTriangular, _add)
704704
checksize1(A, B)
705-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
705+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
706706
for j in axes(B.data,2)
707707
for i in firstindex(B.data,1):j
708708
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
@@ -712,7 +712,7 @@ function _triscale!(A::UpperTriangular, c::Number, B::UpperTriangular, _add)
712712
end
713713
function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Number, _add)
714714
checksize1(A, B)
715-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
715+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
716716
for j in axes(B.data,2)
717717
@inbounds _modify!(_add, c, A, (j,j))
718718
for i in firstindex(B.data,1):(j - 1)
@@ -723,7 +723,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
723723
end
724724
function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriangular, _add)
725725
checksize1(A, B)
726-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
726+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
727727
for j in axes(B.data,2)
728728
@inbounds _modify!(_add, c, A, (j,j))
729729
for i in firstindex(B.data,1):(j - 1)
@@ -734,7 +734,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
734734
end
735735
function _triscale!(A::LowerTriangular, B::LowerTriangular, c::Number, _add)
736736
checksize1(A, B)
737-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
737+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
738738
for j in axes(B.data,2)
739739
for i in j:lastindex(B.data,1)
740740
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
@@ -744,7 +744,7 @@ function _triscale!(A::LowerTriangular, B::LowerTriangular, c::Number, _add)
744744
end
745745
function _triscale!(A::LowerTriangular, c::Number, B::LowerTriangular, _add)
746746
checksize1(A, B)
747-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
747+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
748748
for j in axes(B.data,2)
749749
for i in j:lastindex(B.data,1)
750750
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
@@ -754,7 +754,7 @@ function _triscale!(A::LowerTriangular, c::Number, B::LowerTriangular, _add)
754754
end
755755
function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Number, _add)
756756
checksize1(A, B)
757-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
757+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
758758
for j in axes(B.data,2)
759759
@inbounds _modify!(_add, c, A, (j,j))
760760
for i in (j + 1):lastindex(B.data,1)
@@ -765,7 +765,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
765765
end
766766
function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriangular, _add)
767767
checksize1(A, B)
768-
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
768+
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
769769
for j in axes(B.data,2)
770770
@inbounds _modify!(_add, c, A, (j,j))
771771
for i in (j + 1):lastindex(B.data,1)

0 commit comments

Comments
 (0)