Skip to content

Commit 907be45

Browse files
committed
Branch in diag-bidiag mul
1 parent bf7672d commit 907be45

File tree

1 file changed

+58
-30
lines changed

1 file changed

+58
-30
lines changed

src/bidiag.jl

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -587,22 +587,27 @@ function _diag(A::Bidiagonal, k)
587587
end
588588
end
589589

590+
_MulAddMul_nonzeroalpha(_add::MulAddMul) = _add
591+
function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bis0}
592+
(; beta) = _add
593+
MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
594+
end
595+
590596
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
591597
_bibimul!(C, A, B, _add)
592598
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
593599
_bibimul!(C, A, B, _add)
594-
function _bibimul!(C, A, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
600+
function _bibimul!(C, A, B, _add::MulAddMul)
595601
require_one_based_indexing(C)
596602
matmul_size_check(size(C), size(A), size(B))
597603
n = size(A,1)
598604
iszero(n) && return C
599605
# We use `_rmul_or_fill!` instead of `_modify!` here since using
600606
# `_modify!` in the following loop will not update the
601607
# off-diagonal elements for non-zero beta.
602-
(; alpha, beta) = _add
603-
_rmul_or_fill!(C, beta)
604-
ais1 || (iszero(alpha) && return C)
605-
_add_nonzeroalpha = alpha isa Bool ? MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) : _add
608+
_rmul_or_fill!(C, _add.beta)
609+
iszero(_add.alpha) && return C
610+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
606611
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
607612
C
608613
end
@@ -844,18 +849,15 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
844849
C
845850
end
846851

847-
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number)
852+
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
848853
require_one_based_indexing(C)
849854
matmul_size_check(size(C), size(A), size(B))
850855
n = size(A,1)
851856
iszero(n) && return C
852-
_rmul_or_fill!(C, beta) # see the same use above
853-
iszero(alpha) && return C
854-
if alpha isa Bool
855-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
856-
else
857-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
858-
end
857+
_rmul_or_fill!(C, _add.beta) # see the same use above
858+
iszero(_add.alpha) && return C
859+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
860+
_bidimul!(C, A, B, _add_nonzeroalpha)
859861
C
860862
end
861863
function _bidimul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
@@ -921,16 +923,13 @@ function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMu
921923
C
922924
end
923925

924-
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, alpha::Number, beta::Number)
926+
function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
925927
matmul_size_check(size(C), size(A), size(B))
926928
n = size(A,1)
927929
iszero(n) && return C
928-
iszero(alpha) && return _rmul_or_fill!(C, beta)
929-
if alpha isa Bool
930-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(true, beta))
931-
else
932-
@stable_muladdmul _bidimul!(C, A, B, MulAddMul(alpha, beta))
933-
end
930+
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
931+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
932+
_bidimul!(C, A, B, _add_nonzeroalpha)
934933
C
935934
end
936935
function _bidimul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
@@ -968,7 +967,15 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
968967
nA = size(A,1)
969968
nB = size(B,2)
970969
(iszero(nA) || iszero(nB)) && return C
971-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
970+
(; alpha, beta) = _add
971+
iszero(alpha) && return _rmul_or_fill!(C, beta)
972+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
973+
_mul_bitrisym_left!(C, A, B, _add_nonzeroalpha)
974+
return C
975+
end
976+
function _mul_bitrisym_left!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul)
977+
nA = size(A,1)
978+
nB = size(B,2)
972979
if nA <= 3
973980
# naive multiplication
974981
for I in CartesianIndices(C)
@@ -978,6 +985,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
978985
return C
979986
end
980987
_mul_bitrisym!(C, A, B, _add)
988+
return C
981989
end
982990
function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul)
983991
nA = size(A,1)
@@ -1037,7 +1045,15 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10371045
matmul_size_check(size(C), size(A), size(B))
10381046
n = size(A,1)
10391047
m = size(B,2)
1040-
(iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta)
1048+
(; alpha, beta) = _add
1049+
(iszero(alpha) || iszero(m)) && return _rmul_or_fill!(C, beta)
1050+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1051+
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
1052+
C
1053+
end
1054+
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
1055+
n = size(A,1)
1056+
m = size(B,2)
10411057
if m == 1
10421058
B11 = B[1,1]
10431059
return mul!(C, A, B11, _add.alpha, _add.beta)
@@ -1073,7 +1089,14 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10731089
matmul_size_check(size(C), size(A), size(B))
10741090
m, n = size(A)
10751091
(iszero(m) || iszero(n)) && return C
1076-
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1092+
(; alpha, beta) = _add
1093+
iszero(alpha) && return _rmul_or_fill!(C, beta)
1094+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1095+
_mul_bitrisym_right!(C, A, B, _add_nonzeroalpha)
1096+
C
1097+
end
1098+
function _mul_bitrisym_right!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
1099+
m, n = size(A)
10771100
@inbounds if B.uplo == 'U'
10781101
for j in n:-1:2, i in 1:m
10791102
_modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j))
@@ -1106,6 +1129,12 @@ function _dibimul!(C, A, B, _add)
11061129
# ensure that we fill off-band elements in the destination
11071130
_rmul_or_fill!(C, _add.beta)
11081131
iszero(_add.alpha) && return C
1132+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1133+
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
1134+
C
1135+
end
1136+
function _dibimul_nonzeroalpha!(C, A, B, _add)
1137+
n = size(A,1)
11091138
if n <= 3
11101139
# For simplicity, use a naive multiplication for small matrices
11111140
# that loops over all elements.
@@ -1142,14 +1171,8 @@ function _dibimul!(C, A, B, _add)
11421171
end # inbounds
11431172
C
11441173
end
1145-
function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
1146-
require_one_based_indexing(C)
1147-
matmul_size_check(size(C), size(A), size(B))
1174+
function _dibimul_nonzeroalpha!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add)
11481175
n = size(A,1)
1149-
iszero(n) && return C
1150-
# ensure that we fill off-band elements in the destination
1151-
_rmul_or_fill!(C, _add.beta)
1152-
iszero(_add.alpha) && return C
11531176
Ad = A.diag
11541177
Bdv, Bev = B.dv, B.ev
11551178
rowshift = B.uplo == 'U' ? -1 : 1
@@ -1179,6 +1202,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11791202
n = size(A,1)
11801203
n == 0 && return C
11811204
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
1205+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
1206+
_dibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
1207+
C
1208+
end
1209+
function _dibimul_nonzeroalpha!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11821210
Ad = A.diag
11831211
Bdv, Bev = B.dv, B.ev
11841212
Cdv, Cev = C.dv, C.ev

0 commit comments

Comments
 (0)