Skip to content

Commit 3faeaa2

Browse files
committed
Branch on Bool alpha in bidiag matmul
1 parent 763f19f commit 3faeaa2

File tree

1 file changed

+34
-17
lines changed

1 file changed

+34
-17
lines changed

src/bidiag.jl

Lines changed: 34 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -602,16 +602,23 @@ _mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
602602
_bibimul!(C, A, B, _add)
603603
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
604604
_bibimul!(C, A, B, _add)
605-
function _bibimul!(C, A, B, _add)
605+
function _bibimul!(C, A, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
606606
require_one_based_indexing(C)
607607
matmul_size_check(size(C), size(A), size(B))
608608
n = size(A,1)
609609
iszero(n) && return C
610610
# We use `_rmul_or_fill!` instead of `_modify!` here since using
611611
# `_modify!` in the following loop will not update the
612612
# off-diagonal elements for non-zero beta.
613-
_rmul_or_fill!(C, _add.beta)
614-
iszero(_add.alpha) && return C
613+
(; alpha, beta) = _add
614+
_rmul_or_fill!(C, beta)
615+
ais1 || (iszero(alpha) && return C)
616+
_add_nonzeroalpha = alpha isa Bool ? MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) : _add
617+
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
618+
C
619+
end
620+
function _bibimul_nonzeroalpha!(C, A, B, _add)
621+
n = size(A,1)
615622
if n <= 3
616623
# naive multiplication
617624
for I in CartesianIndices(C)
@@ -848,15 +855,22 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
848855
C
849856
end
850857

851-
_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) =
852-
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
853-
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
858+
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
854859
require_one_based_indexing(C)
855860
matmul_size_check(size(C), size(A), size(B))
856861
n = size(A,1)
857862
iszero(n) && return C
858-
_rmul_or_fill!(C, _add.beta) # see the same use above
859-
iszero(_add.alpha) && return C
863+
_rmul_or_fill!(C, beta) # see the same use above
864+
iszero(alpha) && return C
865+
if alpha isa Bool
866+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
867+
else
868+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
869+
end
870+
C
871+
end
872+
function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
873+
n = size(A,1)
860874
Al = _diag(A, -1)
861875
Ad = _diag(A, 0)
862876
Au = _diag(A, 1)
@@ -892,14 +906,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
892906
end # inbounds
893907
C
894908
end
895-
896-
function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
897-
require_one_based_indexing(C)
898-
matmul_size_check(size(C), size(A), size(B))
909+
function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
899910
n = size(A,1)
900-
iszero(n) && return C
901-
_rmul_or_fill!(C, _add.beta) # see the same use above
902-
iszero(_add.alpha) && return C
903911
(; dv, ev) = A
904912
Bd = B.diag
905913
rowshift = A.uplo == 'U' ? -1 : 1
@@ -924,11 +932,20 @@ function _mul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
924932
C
925933
end
926934

927-
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
935+
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, alpha::Number, beta::Number)
928936
matmul_size_check(size(C), size(A), size(B))
929937
n = size(A,1)
930938
iszero(n) && return C
931-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
939+
iszero(alpha) && return _rmul_or_fill!(C, beta)
940+
if alpha isa Bool
941+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
942+
else
943+
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
944+
end
945+
C
946+
end
947+
function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
948+
n = size(A,1)
932949
Adv, Aev = A.dv, A.ev
933950
Cdv, Cev = C.dv, C.ev
934951
Bd = B.diag

0 commit comments

Comments
 (0)