diff --git a/src/bidiag.jl b/src/bidiag.jl index 05c1371c..19a80336 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -613,7 +613,7 @@ function _bibimul!(C, A, B, _add) # `_modify!` in the following loop will not update the # off-diagonal elements for non-zero beta. _rmul_or_fill!(C, _add.beta) - iszero(_add.alpha) && return C + _iszero_alpha(_add) && return C if n <= 3 # naive multiplication for I in CartesianIndices(C) @@ -858,7 +858,7 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) n = size(A,1) iszero(n) && return C _rmul_or_fill!(C, _add.beta) # see the same use above - iszero(_add.alpha) && return C + _iszero_alpha(_add) && return C Al = _diag(A, -1) Ad = _diag(A, 0) Au = _diag(A, 1) @@ -901,7 +901,7 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul) n = size(A,1) iszero(n) && return C _rmul_or_fill!(C, _add.beta) # see the same use above - iszero(_add.alpha) && return C + _iszero_alpha(_add) && return C (; dv, ev) = A Bd = B.diag rowshift = A.uplo == 'U' ? -1 : 1 @@ -930,7 +930,7 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul) matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) Adv, Aev = A.dv, A.ev Cdv, Cev = C.dv, C.ev Bd = B.diag @@ -964,7 +964,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA nA = size(A,1) nB = size(B,2) (iszero(nA) || iszero(nB)) && return C - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) if nA == 1 A11 = @inbounds A[1,1] for i in axes(B, 2) @@ -1032,7 +1032,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul) matmul_size_check(size(C), size(A), size(B)) n = size(A,1) m = size(B,2) - (iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta) + (_iszero_alpha(_add) || iszero(m)) && return _rmul_or_fill!(C, _add.beta) if m == 1 B11 = B[1,1] return mul!(C, A, B11, _add.alpha, _add.beta) @@ -1068,7 +1068,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd matmul_size_check(size(C), size(A), size(B)) m, n = size(A) (iszero(m) || iszero(n)) && return C - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) @inbounds if B.uplo == 'U' for j in n:-1:2, i in 1:m _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) @@ -1100,7 +1100,7 @@ function _dibimul!(C, A, B, _add) iszero(n) && return C # ensure that we fill off-band elements in the destination _rmul_or_fill!(C, _add.beta) - iszero(_add.alpha) && return C + _iszero_alpha(_add) && return C if n <= 3 # For simplicity, use a naive multiplication for small matrices # that loops over all elements. @@ -1144,7 +1144,7 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add) iszero(n) && return C # ensure that we fill off-band elements in the destination _rmul_or_fill!(C, _add.beta) - iszero(_add.alpha) && return C + _iszero_alpha(_add) && return C Ad = A.diag Bdv, Bev = B.dv, B.ev rowshift = B.uplo == 'U' ? -1 : 1 @@ -1173,7 +1173,7 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add) matmul_size_check(size(C), size(A), size(B)) n = size(A,1) n == 0 && return C - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) Ad = A.diag Bdv, Bev = B.dv, B.ev Cdv, Cev = C.dv, C.ev diff --git a/src/generic.jl b/src/generic.jl index de37f081..5b49b8ca 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -124,6 +124,9 @@ MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false) @inline (p::MulAddMul{true, false})(x, y) = x + y * p.beta @inline (p::MulAddMul{false, false})(x, y) = x * p.alpha + y * p.beta +_iszero_alpha(m::MulAddMul) = iszero(m.alpha) +_iszero_alpha(m::MulAddMul{true}) = false + """ _modify!(_add::MulAddMul, x, C, idx) diff --git a/src/triangular.jl b/src/triangular.jl index ad66f492..ff8d19da 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -692,7 +692,7 @@ end function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) for i in firstindex(B.data,1):j @inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j)) @@ -702,7 +702,7 @@ function _triscale!(A::UpperTriangular, B::UpperTriangular, c::Number, _add) end function _triscale!(A::UpperTriangular, c::Number, B::UpperTriangular, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) for i in firstindex(B.data,1):j @inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j)) @@ -712,7 +712,7 @@ function _triscale!(A::UpperTriangular, c::Number, B::UpperTriangular, _add) end function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Number, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) @inbounds _modify!(_add, c, A, (j,j)) for i in firstindex(B.data,1):(j - 1) @@ -723,7 +723,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu end function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriangular, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) @inbounds _modify!(_add, c, A, (j,j)) for i in firstindex(B.data,1):(j - 1) @@ -734,7 +734,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang end function _triscale!(A::LowerTriangular, B::LowerTriangular, c::Number, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) for i in j:lastindex(B.data,1) @inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j)) @@ -744,7 +744,7 @@ function _triscale!(A::LowerTriangular, B::LowerTriangular, c::Number, _add) end function _triscale!(A::LowerTriangular, c::Number, B::LowerTriangular, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) for i in j:lastindex(B.data,1) @inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j)) @@ -754,7 +754,7 @@ function _triscale!(A::LowerTriangular, c::Number, B::LowerTriangular, _add) end function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Number, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) @inbounds _modify!(_add, c, A, (j,j)) for i in (j + 1):lastindex(B.data,1) @@ -765,7 +765,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu end function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriangular, _add) checksize1(A, B) - iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta) + _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) @inbounds _modify!(_add, c, A, (j,j)) for i in (j + 1):lastindex(B.data,1) diff --git a/test/bidiag.jl b/test/bidiag.jl index e111433d..2488cd3f 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -996,7 +996,9 @@ end @test A * D ≈ mul!(S, A, D) ≈ M * D @test D * A ≈ mul!(S, D, A) ≈ D * M @test mul!(copy(S), D, A, 2, 2) ≈ D * M * 2 + S * 2 + @test mul!(copy(S), D, A, 0, 2) ≈ D * M * 0 + S * 2 @test mul!(copy(S), A, D, 2, 2) ≈ M * D * 2 + S * 2 + @test mul!(copy(S), A, D, 0, 2) ≈ M * D * 0 + S * 2 A2 = Bidiagonal(dv, zero(ev), uplo) M2 = Array(A2) @@ -1074,10 +1076,12 @@ end @test B * v ≈ M * v @test mul!(similar(v), B, v) ≈ M * v @test mul!(ones(size(v)), B, v, 2, 3) ≈ M * v * 2 .+ 3 + @test mul!(ones(size(v)), B, v, 0, 3) ≈ M * v * 0 .+ 3 @test B * B ≈ M * M @test mul!(similar(B, size(B)), B, B) ≈ M * M @test mul!(ones(size(B)), B, B, 2, 4) ≈ M * M * 2 .+ 4 + @test mul!(ones(size(B)), B, B, 0, 4) ≈ M * M * 0 .+ 4 for m in 0:6 AL = rand(m,n)