Skip to content

Commit bf7672d

Browse files
committed
Branch on Bool alpha in bidiag matmul
1 parent 925acef commit bf7672d

File tree

1 file changed

+34
-17
lines changed

1 file changed

+34
-17
lines changed

src/bidiag.jl

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,23 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
591591
_bibimul!(C, A, B, _add)
592592
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
593593
_bibimul!(C, A, B, _add)
594-
function _bibimul!(C, A, B, _add)
594+
function _bibimul!(C, A, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
595595
require_one_based_indexing(C)
596596
matmul_size_check(size(C), size(A), size(B))
597597
n = size(A,1)
598598
iszero(n) && return C
599599
# We use `_rmul_or_fill!` instead of `_modify!` here since using
600600
# `_modify!` in the following loop will not update the
601601
# off-diagonal elements for non-zero beta.
602-
_rmul_or_fill!(C, _add.beta)
603-
iszero(_add.alpha) && return C
602+
(; alpha, beta) = _add
603+
_rmul_or_fill!(C, beta)
604+
ais1 || (iszero(alpha) && return C)
605+
_add_nonzeroalpha = alpha isa Bool ? MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) : _add
606+
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
607+
C
608+
end
609+
function _bibimul_nonzeroalpha!(C, A, B, _add)
610+
n = size(A,1)
604611
if n <= 3
605612
# naive multiplication
606613
for I in CartesianIndices(C)
@@ -837,15 +844,22 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
837844
C
838845
end
839846

840-
_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) =
841-
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
842-
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
847+
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
843848
require_one_based_indexing(C)
844849
matmul_size_check(size(C), size(A), size(B))
845850
n = size(A,1)
846851
iszero(n) && return C
847-
_rmul_or_fill!(C, _add.beta) # see the same use above
848-
iszero(_add.alpha) && return C
852+
_rmul_or_fill!(C, beta) # see the same use above
853+
iszero(alpha) && return C
854+
if alpha isa Bool
855+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
856+
else
857+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
858+
end
859+
C
860+
end
861+
function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
862+
n = size(A,1)
849863
Al = _diag(A, -1)
850864
Ad = _diag(A, 0)
851865
Au = _diag(A, 1)
@@ -881,14 +895,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
881895
end # inbounds
882896
C
883897
end
884-
885-
function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
886-
require_one_based_indexing(C)
887-
matmul_size_check(size(C), size(A), size(B))
898+
function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
888899
n = size(A,1)
889-
iszero(n) && return C
890-
_rmul_or_fill!(C, _add.beta) # see the same use above
891-
iszero(_add.alpha) && return C
892900
(; dv, ev) = A
893901
Bd = B.diag
894902
rowshift = A.uplo == 'U' ? -1 : 1
@@ -913,11 +921,20 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
913921
C
914922
end
915923

916-
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
924+
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, alpha::Number, beta::Number)
917925
matmul_size_check(size(C), size(A), size(B))
918926
n = size(A,1)
919927
iszero(n) && return C
920-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
928+
iszero(alpha) && return _rmul_or_fill!(C, beta)
929+
if alpha isa Bool
930+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
931+
else
932+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
933+
end
934+
C
935+
end
936+
function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
937+
n = size(A,1)
921938
Adv, Aev = A.dv, A.ev
922939
Cdv, Cev = C.dv, C.ev
923940
Bd = B.diag

0 commit comments

Comments
 (0)