Skip to content

Commit 3cb8023

Browse files
committed
Branch in diag-bidiag mul
1 parent 3faeaa2 commit 3cb8023

File tree

1 file changed

+54
-29
lines changed

1 file changed

+54
-29
lines changed

src/bidiag.jl

Lines changed: 54 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -598,22 +598,27 @@ function _diag(A::Bidiagonal, k)
598598
end
599599
end
600600

601+
# Produce the `MulAddMul` that is handed to the multiplication kernels once the
# caller has already dealt with a zero `alpha` (the call sites in this diff
# invoke it right after an `iszero(_add.alpha)` early return).
# Generic fallback: the object is passed through unchanged.
_MulAddMul_nonzeroalpha(_add::MulAddMul) = _add
602+
# Method for a `MulAddMul` whose third type parameter is `Bool`: rebuild it with
# the first type parameter and the stored alpha both set to `true`, keeping
# `bis0` and `beta` exactly as they were.
# NOTE(review): presumably a nonzero Bool alpha can only be `true`, so this
# collapses the `ais1` parameter to a single value — confirm against the
# definition of `MulAddMul`.
function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bis0}
603+
(; beta) = _add
604+
MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
605+
end
606+
601607
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
602608
_bibimul!(C, A, B, _add)
603609
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
604610
_bibimul!(C, A, B, _add)
605-
function _bibimul!(C, A, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
611+
function _bibimul!(C, A, B, _add::MulAddMul)
606612
require_one_based_indexing(C)
607613
matmul_size_check(size(C), size(A), size(B))
608614
n = size(A,1)
609615
iszero(n) && return C
610616
# We use `_rmul_or_fill!` instead of `_modify!` here since using
611617
# `_modify!` in the following loop will not update the
612618
# off-diagonal elements for non-zero beta.
613-
(; alpha, beta) = _add
614-
_rmul_or_fill!(C, beta)
615-
ais1 || (iszero(alpha) && return C)
616-
_add_nonzeroalpha = alpha isa Bool ? MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) : _add
619+
_rmul_or_fill!(C, _add.beta)
620+
iszero(_add.alpha) && return C
621+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
617622
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
618623
C
619624
end
@@ -855,18 +860,15 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
855860
C
856861
end
857862

858-
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
863+
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
859864
require_one_based_indexing(C)
860865
matmul_size_check(size(C), size(A), size(B))
861866
n = size(A,1)
862867
iszero(n) && return C
863-
_rmul_or_fill!(C, beta) # see the same use above
864-
iszero(alpha) && return C
865-
if alpha isa Bool
866-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
867-
else
868-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
869-
end
868+
_rmul_or_fill!(C, _add.beta) # see the same use above
869+
iszero(_add.alpha) && return C
870+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
871+
_bidimul!(C, A, B, _add_nonzeroalpha)
870872
C
871873
end
872874
function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
@@ -932,16 +934,13 @@ function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMu
932934
C
933935
end
934936

935-
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, alpha::Number, beta::Number)
937+
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
936938
matmul_size_check(size(C), size(A), size(B))
937939
n = size(A,1)
938940
iszero(n) && return C
939-
iszero(alpha) && return _rmul_or_fill!(C, beta)
940-
if alpha isa Bool
941-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
942-
else
943-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
944-
end
941+
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
942+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
943+
_bidimul!(C, A, B, _add_nonzeroalpha)
945944
C
946945
end
947946
function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
@@ -980,14 +979,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
980979
nB = size(B,2)
981980
(iszero(nA) || iszero(nB)) && return C
982981
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
982+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
983+
_mul_bitrisym_left!(C, A, B, _add_nonzeroalpha)
984+
return C
985+
end
986+
function _mul_bitrisym_left!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul)
987+
nA = size(A,1)
988+
nB = size(B,2)
983989
if nA == 1
984990
A11 = @inbounds A[1,1]
985991
for i in axes(B, 2)
986992
@inbounds _modify!(_add, A11 * B[1,i], C, (1,i))
987993
end
988-
return C
994+
else
995+
_mul_bitrisym!(C, A, B, _add)
989996
end
990-
_mul_bitrisym!(C, A, B, _add)
997+
return C
991998
end
992999
function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul)
9931000
nA = size(A,1)
@@ -1048,6 +1055,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10481055
n = size(A,1)
10491056
m = size(B,2)
10501057
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
1058+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1059+
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
1060+
C
1061+
end
1062+
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
1063+
n = size(A,1)
1064+
m = size(B,2)
10511065
if m == 1
10521066
B11 = B[1,1]
10531067
return mul!(C, A, B11, _add.alpha, _add.beta)
@@ -1084,6 +1098,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10841098
m, n = size(A)
10851099
(iszero(m) || iszero(n)) && return C
10861100
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1101+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1102+
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
1103+
C
1104+
end
1105+
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
1106+
m, n = size(A)
10871107
@inbounds if B.uplo == 'U'
10881108
for j in n:-1:2, i in 1:m
10891109
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
@@ -1116,6 +1136,12 @@ function _dibimul!(C, A, B, _add)
11161136
# ensure that we fill off-band elements in the destination
11171137
_rmul_or_fill!(C, _add.beta)
11181138
iszero(_add.alpha) && return C
1139+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1140+
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
1141+
C
1142+
end
1143+
function _dibimul_nonzeroalpha!(C, A, B, _add)
1144+
n = size(A,1)
11191145
if n <= 3
11201146
# For simplicity, use a naive multiplication for small matrices
11211147
# that loops over all elements.
@@ -1152,14 +1178,8 @@ function _dibimul!(C, A, B, _add)
11521178
end # inbounds
11531179
C
11541180
end
1155-
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
1156-
require_one_based_indexing(C)
1157-
matmul_size_check(size(C), size(A), size(B))
1181+
function _dibimul_nonzeroalpha!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11581182
n = size(A,1)
1159-
iszero(n) && return C
1160-
# ensure that we fill off-band elements in the destination
1161-
_rmul_or_fill!(C, _add.beta)
1162-
iszero(_add.alpha) && return C
11631183
Ad = A.diag
11641184
Bdv, Bev = B.dv, B.ev
11651185
rowshift = B.uplo == 'U' ? -1 : 1
@@ -1189,6 +1209,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11891209
n = size(A,1)
11901210
n == 0 && return C
11911211
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1212+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1213+
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
1214+
C
1215+
end
1216+
function _dibimul_nonzeroalpha!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11921217
Ad = A.diag
11931218
Bdv, Bev = B.dv, B.ev
11941219
Cdv, Cev = C.dv, C.ev

0 commit comments

Comments
 (0)