diff --git a/src/bidiag.jl b/src/bidiag.jl index 19a80336..a9a145e6 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -600,6 +600,27 @@ function _diag(A::Bidiagonal, k) end end +""" + _MulAddMul_nonzeroalpha(_add::MulAddMul[, ::Val{false}]) + +Return a new `MulAddMul` with the value of `alpha` potentially set to a literal non-zero +value if permitted by the type (e.g., for `_add.alpha isa Bool`, in which case the `alpha` is +set to `true` in the returned instance). +In other cases, the single-argument call is a no-op and returns `_add` without modifications. + +In addition, if `Val(false)` is provided as the second argument, +`beta` is set to `false` in the returned `MulAddMul` instance. +""" +_MulAddMul_nonzeroalpha(_add::MulAddMul) = _add +function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,A}, ::Val{false}) where {ais1,bis0,A} + MulAddMul{ais1,true,A,Bool}(_add.alpha, false) +end +function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bis0} + (; beta) = _add + MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) +end +_MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}, ::Val{false}) where {ais1,bis0} = MulAddMul() + _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) = _bibimul!(C, A, B, _add) _mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) = @@ -613,36 +634,54 @@ function _bibimul!(C, A, B, _add) # `_modify!` in the following loop will not update the # off-diagonal elements for non-zero beta. _rmul_or_fill!(C, _add.beta) - _iszero_alpha(_add) && return C - if n <= 3 + iszero(_add.alpha) && return C + # beta is unused in _bibimul_nonzeroalpha!, so we set it to false + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false)) + _bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha) + C +end +function _bibimul_nonzeroalpha!(C, A, B, _add) + n = size(A,1) + if n == 1 # naive multiplication - for I in CartesianIndices(C) - C[I] += _add(sum(A[I[1], k] * B[k, I[2]] for k in axes(A,2))) - end + @inbounds C[1,1] += _add(A[1,1] * B[1,1]) return C end @inbounds begin # first column of C C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2,1]) C[2,1] += _add(A[2,1]*B[1,1] + A[2,2]*B[2,1]) - C[3,1] += _add(A[3,2]*B[2,1]) + if n >= 3 + C[3,1] += _add(A[3,2]*B[2,1]) + end # second column of C C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2]) - C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]) - C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2]) - C[4,2] += _add(A[4,3]*B[3,2]) + C22 = A[2,1]*B[1,2] + A[2,2]*B[2,2] + if n >= 3 + C[2,2] += _add(C22 + A[2,3]*B[3,2]) + C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2]) + if n >= 4 + C[4,2] += _add(A[4,3]*B[3,2]) + end + else + C[2,2] += _add(C22) + end end # inbounds # middle columns __bibimul!(C, A, B, _add) @inbounds begin - C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1]) - C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1]) - C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) - C[n, n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) + if n >= 4 + C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1]) + C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1]) + C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) + C[n, n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) + end # last column of C - C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n]) - C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ]) - C[n, n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) + if n >= 3 + C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n]) + C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ]) + C[n, n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) + end end # inbounds C end @@ -683,9 +722,9 @@ function __bibimul!(C, A, B::Bidiagonal, _add) Al = _diag(A, -1) Ad = _diag(A, 0) Au = _diag(A, 1) - Bd = _diag(B, 0) + Bd = B.dv if B.uplo == 'U' - Bu = _diag(B, 1) + Bu = B.ev @inbounds begin for j in 3:n-2 Aj₋2j₋1 = Au[j-2] @@ -704,7 +743,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add) end end else # B.uplo == 'L' - Bl = _diag(B, -1) + Bl = B.ev @inbounds begin for j in 3:n-2 Aj₋1j = Au[j-1] @@ -730,9 +769,9 @@ function __bibimul!(C, A::Bidiagonal, B, _add) Bl = _diag(B, -1) Bd = _diag(B, 0) Bu = _diag(B, 1) - Ad = _diag(A, 0) + Ad = A.dv if A.uplo == 'U' - Au = _diag(A, 1) + Au = A.ev @inbounds begin for j in 3:n-2 Aj₋2j₋1 = Au[j-2] @@ -752,7 +791,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add) end end else # A.uplo == 'L' - Al = _diag(A, -1) + Al = A.ev @inbounds begin for j in 3:n-2 Aj₋1j₋1 = Ad[j-1] @@ -776,11 +815,11 @@ function __bibimul!(C, A::Bidiagonal, B, _add) end function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) n = size(A,1) - Ad = _diag(A, 0) - Bd = _diag(B, 0) + Ad = A.dv + Bd = B.dv if A.uplo == 'U' && B.uplo == 'U' - Au = _diag(A, 1) - Bu = _diag(B, 1) + Au = A.ev + Bu = B.ev @inbounds begin for j in 3:n-2 Aj₋2j₋1 = Au[j-2] @@ -796,8 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) end end elseif A.uplo == 'U' && B.uplo == 'L' - Au = _diag(A, 1) - Bl = _diag(B, -1) + Au = A.ev + Bl = B.ev @inbounds begin for j in 3:n-2 Aj₋1j = Au[j-1] @@ -813,8 +852,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) end end elseif A.uplo == 'L' && B.uplo == 'U' - Al = _diag(A, -1) - Bu = _diag(B, 1) + Al = A.ev + Bu = B.ev @inbounds begin for j in 3:n-2 Aj₋1j₋1 = Ad[j-1] @@ -830,8 +869,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) end end else # A.uplo == 'L' && B.uplo == 'L' - Al = _diag(A, -1) - Bl = _diag(B, -1) + Al = A.ev + Bl = B.ev @inbounds begin for j in 3:n-2 Ajj = Ad[j] @@ -850,15 +889,20 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) C end -_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = - @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) require_one_based_indexing(C) matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C _rmul_or_fill!(C, _add.beta) # see the same use above - _iszero_alpha(_add) && return C + iszero(_add.alpha) && return C + # beta is unused in the _bidimul! call, so we set it to false + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false)) + _bidimul!(C, A, B, _add_nonzeroalpha) + C +end +function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) + n = size(A,1) Al = _diag(A, -1) Ad = _diag(A, 0) Au = _diag(A, 1) @@ -894,14 +938,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) end # inbounds C end - -function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul) - require_one_based_indexing(C) - matmul_size_check(size(C), size(A), size(B)) +function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul) n = size(A,1) - iszero(n) && return C - _rmul_or_fill!(C, _add.beta) # see the same use above - _iszero_alpha(_add) && return C (; dv, ev) = A Bd = B.diag rowshift = A.uplo == 'U' ? -1 : 1 @@ -930,7 +968,13 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul) matmul_size_check(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C - _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) + iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add) + _bidimul!(C, A, B, _add_nonzeroalpha) + C +end +function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul) + n = size(A,1) Adv, Aev = A.dv, A.ev Cdv, Cev = C.dv, C.ev Bd = B.diag @@ -965,14 +1009,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA nB = size(B,2) (iszero(nA) || iszero(nB)) && return C _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add) + _mul_bitrisym_left!(C, A, B, _add_nonzeroalpha) + return C +end +function _mul_bitrisym_left!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul) + nA = size(A,1) + nB = size(B,2) if nA == 1 A11 = @inbounds A[1,1] for i in axes(B, 2) @inbounds _modify!(_add, A11 * B[1,i], C, (1,i)) end - return C + else + _mul_bitrisym!(C, A, B, _add) end - _mul_bitrisym!(C, A, B, _add) + return C end function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul) nA = size(A,1) @@ -1033,6 +1085,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul) n = size(A,1) m = size(B,2) (_iszero_alpha(_add) || iszero(m)) && return _rmul_or_fill!(C, _add.beta) + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add) + _mul_bitrisym_right!(C, A, B, _add_nonzeroalpha) + C +end +function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul) + n = size(A,1) + m = size(B,2) if m == 1 B11 = B[1,1] return mul!(C, A, B11, _add.alpha, _add.beta) @@ -1069,6 +1128,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd m, n = size(A) (iszero(m) || iszero(n)) && return C _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add) + _mul_bitrisym_right!(C, A, B, _add_nonzeroalpha) + C +end +function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul) + m, n = size(A) @inbounds if B.uplo == 'U' for j in n:-1:2, i in 1:m _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) @@ -1101,6 +1166,13 @@ function _dibimul!(C, A, B, _add) # ensure that we fill off-band elements in the destination _rmul_or_fill!(C, _add.beta) _iszero_alpha(_add) && return C + # beta is unused in the _dibimul_nonzeroalpha! call, so we set it to false + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false)) + _dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha) + C +end +function _dibimul_nonzeroalpha!(C, A, B, _add) + n = size(A,1) if n <= 3 # For simplicity, use a naive multiplication for small matrices # that loops over all elements. @@ -1137,14 +1209,8 @@ function _dibimul!(C, A, B, _add) end # inbounds C end -function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add) - require_one_based_indexing(C) - matmul_size_check(size(C), size(A), size(B)) +function _dibimul_nonzeroalpha!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add) n = size(A,1) - iszero(n) && return C - # ensure that we fill off-band elements in the destination - _rmul_or_fill!(C, _add.beta) - _iszero_alpha(_add) && return C Ad = A.diag Bdv, Bev = B.dv, B.ev rowshift = B.uplo == 'U' ? -1 : 1 @@ -1174,6 +1240,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add) n = size(A,1) n == 0 && return C _iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta) + _add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add) + _dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha) + C +end +function _dibimul_nonzeroalpha!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add) Ad = A.diag Bdv, Bev = B.dv, B.ev Cdv, Cev = C.dv, C.ev