Skip to content
175 changes: 123 additions & 52 deletions src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,27 @@ function _diag(A::Bidiagonal, k)
end
end

"""
_MulAddMul_nonzeroalpha(_add::MulAddMul[, ::Val{false}])

Return a new `MulAddMul` with the value of `alpha` potentially set to a literal non-zero
value if permitted by the type (e.g., for `_add.alpha isa Bool`, in which case the `alpha` is
set to `true` in the returned instance).
In other cases, the single-argument call is a no-op and returns `_add` without modifications.

In addition, if `Val(false)` is provided as the second argument,
`beta` is set to `false` in the returned `MulAddMul` instance.
"""
_MulAddMul_nonzeroalpha(_add::MulAddMul) = _add
function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,A}, ::Val{false}) where {ais1,bis0,A}
MulAddMul{ais1,true,A,Bool}(_add.alpha, false)
end
function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bis0}
(; beta) = _add
MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
end
_MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}, ::Val{false}) where {ais1,bis0} = MulAddMul()

_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
Expand All @@ -613,36 +634,54 @@ function _bibimul!(C, A, B, _add)
# `_modify!` in the following loop will not update the
# off-diagonal elements for non-zero beta.
_rmul_or_fill!(C, _add.beta)
_iszero_alpha(_add) && return C
if n <= 3
iszero(_add.alpha) && return C
# beta is unused in _bibimul_nonzeroalpha!, so we set it to false
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
C
end
function _bibimul_nonzeroalpha!(C, A, B, _add)
n = size(A,1)
if n == 1
# naive multiplication
for I in CartesianIndices(C)
C[I] += _add(sum(A[I[1], k] * B[k, I[2]] for k in axes(A,2)))
end
@inbounds C[1,1] += _add(A[1,1] * B[1,1])
return C
end
@inbounds begin
# first column of C
C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2,1])
C[2,1] += _add(A[2,1]*B[1,1] + A[2,2]*B[2,1])
C[3,1] += _add(A[3,2]*B[2,1])
if n >= 3
C[3,1] += _add(A[3,2]*B[2,1])
end
# second column of C
C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2])
C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2])
C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2])
C[4,2] += _add(A[4,3]*B[3,2])
C22 = A[2,1]*B[1,2] + A[2,2]*B[2,2]
if n >= 3
C[2,2] += _add(C22 + A[2,3]*B[3,2])
C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2])
if n >= 4
C[4,2] += _add(A[4,3]*B[3,2])
end
else
C[2,2] += _add(C22)
end
end # inbounds
# middle columns
__bibimul!(C, A, B, _add)
@inbounds begin
C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1])
C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1])
C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1])
C[n, n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1])
if n >= 4
C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1])
C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1])
C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1])
C[n, n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1])
end
# last column of C
C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n])
C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ])
C[n, n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ])
if n >= 3
C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n])
C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ])
C[n, n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ])
end
end # inbounds
C
end
Expand Down Expand Up @@ -683,9 +722,9 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
Al = _diag(A, -1)
Ad = _diag(A, 0)
Au = _diag(A, 1)
Bd = _diag(B, 0)
Bd = B.dv
if B.uplo == 'U'
Bu = _diag(B, 1)
Bu = B.ev
@inbounds begin
for j in 3:n-2
Aj₋2j₋1 = Au[j-2]
Expand All @@ -704,7 +743,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
end
end
else # B.uplo == 'L'
Bl = _diag(B, -1)
Bl = B.ev
@inbounds begin
for j in 3:n-2
Aj₋1j = Au[j-1]
Expand All @@ -730,9 +769,9 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
Bl = _diag(B, -1)
Bd = _diag(B, 0)
Bu = _diag(B, 1)
Ad = _diag(A, 0)
Ad = A.dv
if A.uplo == 'U'
Au = _diag(A, 1)
Au = A.ev
@inbounds begin
for j in 3:n-2
Aj₋2j₋1 = Au[j-2]
Expand All @@ -752,7 +791,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
end
end
else # A.uplo == 'L'
Al = _diag(A, -1)
Al = A.ev
@inbounds begin
for j in 3:n-2
Aj₋1j₋1 = Ad[j-1]
Expand All @@ -776,11 +815,11 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
end
function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
n = size(A,1)
Ad = _diag(A, 0)
Bd = _diag(B, 0)
Ad = A.dv
Bd = B.dv
if A.uplo == 'U' && B.uplo == 'U'
Au = _diag(A, 1)
Bu = _diag(B, 1)
Au = A.ev
Bu = B.ev
@inbounds begin
for j in 3:n-2
Aj₋2j₋1 = Au[j-2]
Expand All @@ -796,8 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
end
end
elseif A.uplo == 'U' && B.uplo == 'L'
Au = _diag(A, 1)
Bl = _diag(B, -1)
Au = A.ev
Bl = B.ev
@inbounds begin
for j in 3:n-2
Aj₋1j = Au[j-1]
Expand All @@ -813,8 +852,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
end
end
elseif A.uplo == 'L' && B.uplo == 'U'
Al = _diag(A, -1)
Bu = _diag(B, 1)
Al = A.ev
Bu = B.ev
@inbounds begin
for j in 3:n-2
Aj₋1j₋1 = Ad[j-1]
Expand All @@ -830,8 +869,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
end
end
else # A.uplo == 'L' && B.uplo == 'L'
Al = _diag(A, -1)
Bl = _diag(B, -1)
Al = A.ev
Bl = B.ev
@inbounds begin
for j in 3:n-2
Ajj = Ad[j]
Expand All @@ -850,15 +889,20 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
C
end

_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
_rmul_or_fill!(C, _add.beta) # see the same use above
_iszero_alpha(_add) && return C
iszero(_add.alpha) && return C
# beta is unused in the _bidimul! call, so we set it to false
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
_bidimul!(C, A, B, _add_nonzeroalpha)
C
end
function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
n = size(A,1)
Al = _diag(A, -1)
Ad = _diag(A, 0)
Au = _diag(A, 1)
Expand Down Expand Up @@ -894,14 +938,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
end # inbounds
C
end

function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
matmul_size_check(size(C), size(A), size(B))
function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
n = size(A,1)
iszero(n) && return C
_rmul_or_fill!(C, _add.beta) # see the same use above
_iszero_alpha(_add) && return C
(; dv, ev) = A
Bd = B.diag
rowshift = A.uplo == 'U' ? -1 : 1
Expand Down Expand Up @@ -930,7 +968,13 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
matmul_size_check(size(C), size(A), size(B))
n = size(A,1)
iszero(n) && return C
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
_bidimul!(C, A, B, _add_nonzeroalpha)
C
end
function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
n = size(A,1)
Adv, Aev = A.dv, A.ev
Cdv, Cev = C.dv, C.ev
Bd = B.diag
Expand Down Expand Up @@ -965,14 +1009,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
nB = size(B,2)
(iszero(nA) || iszero(nB)) && return C
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
_mul_bitrisym_left!(C, A, B, _add_nonzeroalpha)
return C
end
function _mul_bitrisym_left!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
nA = size(A,1)
nB = size(B,2)
if nA == 1
A11 = @inbounds A[1,1]
for i in axes(B, 2)
@inbounds _modify!(_add, A11 * B[1,i], C, (1,i))
end
return C
else
_mul_bitrisym!(C, A, B, _add)
end
_mul_bitrisym!(C, A, B, _add)
return C
end
function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul)
nA = size(A,1)
Expand Down Expand Up @@ -1033,6 +1085,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
n = size(A,1)
m = size(B,2)
(_iszero_alpha(_add) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
C
end
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
n = size(A,1)
m = size(B,2)
if m == 1
B11 = B[1,1]
return mul!(C, A, B11, _add.alpha, _add.beta)
Expand Down Expand Up @@ -1069,6 +1128,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
m, n = size(A)
(iszero(m) || iszero(n)) && return C
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
C
end
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
m, n = size(A)
@inbounds if B.uplo == 'U'
for j in n:-1:2, i in 1:m
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
Expand Down Expand Up @@ -1101,6 +1166,13 @@ function _dibimul!(C, A, B, _add)
# ensure that we fill off-band elements in the destination
_rmul_or_fill!(C, _add.beta)
_iszero_alpha(_add) && return C
# beta is unused in the _dibimul_nonzeroalpha! call, so we set it to false
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
C
end
function _dibimul_nonzeroalpha!(C, A, B, _add)
n = size(A,1)
if n <= 3
# For simplicity, use a naive multiplication for small matrices
# that loops over all elements.
Expand Down Expand Up @@ -1137,14 +1209,8 @@ function _dibimul!(C, A, B, _add)
end # inbounds
C
end
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
require_one_based_indexing(C)
matmul_size_check(size(C), size(A), size(B))
function _dibimul_nonzeroalpha!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
n = size(A,1)
iszero(n) && return C
# ensure that we fill off-band elements in the destination
_rmul_or_fill!(C, _add.beta)
_iszero_alpha(_add) && return C
Ad = A.diag
Bdv, Bev = B.dv, B.ev
rowshift = B.uplo == 'U' ? -1 : 1
Expand Down Expand Up @@ -1174,6 +1240,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
n = size(A,1)
n == 0 && return C
_iszero_alpha(_add) && return _rmul_or_fill!(C, _add.beta)
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
C
end
function _dibimul_nonzeroalpha!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
Ad = A.diag
Bdv, Bev = B.dv, B.ev
Cdv, Cev = C.dv, C.ev
Expand Down