@@ -598,22 +598,27 @@ function _diag(A::Bidiagonal, k)
598598 end
599599end
600600
601+ _MulAddMul_nonzeroalpha (_add:: MulAddMul ) = _add
602+ function _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,Bool} ) where {ais1,bis0}
603+ (; beta) = _add
604+ MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
605+ end
606+
601607_mul! (C:: AbstractMatrix , A:: BiTriSym , B:: TriSym , _add:: MulAddMul ) =
602608 _bibimul! (C, A, B, _add)
603609_mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Bidiagonal , _add:: MulAddMul ) =
604610 _bibimul! (C, A, B, _add)
605- function _bibimul! (C, A, B, _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
611+ function _bibimul! (C, A, B, _add:: MulAddMul )
606612 require_one_based_indexing (C)
607613 matmul_size_check (size (C), size (A), size (B))
608614 n = size (A,1 )
609615 iszero (n) && return C
610616 # We use `_rmul_or_fill!` instead of `_modify!` here since using
611617 # `_modify!` in the following loop will not update the
612618 # off-diagonal elements for non-zero beta.
613- (; alpha, beta) = _add
614- _rmul_or_fill! (C, beta)
615- ais1 || (iszero (alpha) && return C)
616- _add_nonzeroalpha = alpha isa Bool ? MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta) : _add
619+ _rmul_or_fill! (C, _add. beta)
620+ iszero (_add. alpha) && return C
621+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
617622 _bibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
618623 C
619624end
@@ -855,18 +860,15 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
855860 C
856861end
857862
858- function _mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , alpha :: Number , beta :: Number )
863+ function _mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , _add :: MulAddMul )
859864 require_one_based_indexing (C)
860865 matmul_size_check (size (C), size (A), size (B))
861866 n = size (A,1 )
862867 iszero (n) && return C
863- _rmul_or_fill! (C, beta) # see the same use above
864- iszero (alpha) && return C
865- if alpha isa Bool
866- @stable_muladdmul _bidimul! (C, A, B, MulAddMul (true , beta))
867- else
868- @stable_muladdmul _bidimul! (C, A, B, MulAddMul (alpha, beta))
869- end
868+ _rmul_or_fill! (C, _add. beta) # see the same use above
869+ iszero (_add. alpha) && return C
870+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
871+ _bidimul! (C, A, B, _add_nonzeroalpha)
870872 C
871873end
872874function _bidimul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , _add:: MulAddMul )
@@ -932,16 +934,13 @@ function _bidimul!(C::AbstractMatrix, A::Bidiagonal, B::Diagonal, _add::MulAddMu
932934 C
933935end
934936
935- function _mul! (C:: Bidiagonal , A:: Bidiagonal , B:: Diagonal , alpha :: Number , beta :: Number )
937+ function _mul! (C:: Bidiagonal , A:: Bidiagonal , B:: Diagonal , _add :: MulAddMul )
936938 matmul_size_check (size (C), size (A), size (B))
937939 n = size (A,1 )
938940 iszero (n) && return C
939- iszero (alpha) && return _rmul_or_fill! (C, beta)
940- if alpha isa Bool
941- @stable_muladdmul _bidimul! (C, A, B, MulAddMul (true , beta))
942- else
943- @stable_muladdmul _bidimul! (C, A, B, MulAddMul (alpha, beta))
944- end
941+ iszero (_add. alpha) && return _rmul_or_fill! (C, _add. beta)
942+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
943+ _bidimul! (C, A, B, _add_nonzeroalpha)
945944 C
946945end
947946function _bidimul! (C:: Bidiagonal , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
@@ -980,14 +979,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
980979 nB = size (B,2 )
981980 (iszero (nA) || iszero (nB)) && return C
982981 iszero (_add. alpha) && return _rmul_or_fill! (C, _add. beta)
982+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
983+ _mul_bitrisym_left! (C, A, B, _add_nonzeroalpha)
984+ return C
985+ end
986+ function _mul_bitrisym_left! (C:: AbstractVecOrMat , A:: Bidiagonal , B:: AbstractVecOrMat , _add:: MulAddMul )
987+ nA = size (A,1 )
988+ nB = size (B,2 )
983989 if nA == 1
984990 A11 = @inbounds A[1 ,1 ]
985991 for i in axes (B, 2 )
986992 @inbounds _modify! (_add, A11 * B[1 ,i], C, (1 ,i))
987993 end
988- return C
994+ else
995+ _mul_bitrisym! (C, A, B, _add)
989996 end
990- _mul_bitrisym! (C, A, B, _add)
997+ return C
991998end
992999function _mul_bitrisym! (C:: AbstractVecOrMat , A:: Bidiagonal , B:: AbstractVecOrMat , _add:: MulAddMul )
9931000 nA = size (A,1 )
@@ -1048,6 +1055,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10481055 n = size (A,1 )
10491056 m = size (B,2 )
10501057 (iszero (_add. alpha) || iszero (m)) && return _rmul_or_fill! (C, _add. beta)
1058+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1059+ _mul_bitrisym_right! (C, A, B, _add_nonzeroalpha)
1060+ C
1061+ end
1062+ function _mul_bitrisym_right! (C:: AbstractMatrix , A:: AbstractMatrix , B:: TriSym , _add:: MulAddMul )
1063+ n = size (A,1 )
1064+ m = size (B,2 )
10511065 if m == 1
10521066 B11 = B[1 ,1 ]
10531067 return mul! (C, A, B11, _add. alpha, _add. beta)
@@ -1084,6 +1098,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10841098 m, n = size (A)
10851099 (iszero (m) || iszero (n)) && return C
10861100 iszero (_add. alpha) && return _rmul_or_fill! (C, _add. beta)
1101+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1102+ _mul_bitrisym_right! (C, A, B, _add_nonzeroalpha)
1103+ C
1104+ end
1105+ function _mul_bitrisym_right! (C:: AbstractMatrix , A:: AbstractMatrix , B:: Bidiagonal , _add:: MulAddMul )
1106+ m, n = size (A)
10871107 @inbounds if B. uplo == ' U'
10881108 for j in n: - 1 : 2 , i in 1 : m
10891109 _modify! (_add, A[i,j] * B. dv[j] + A[i,j- 1 ] * B. ev[j- 1 ], C, (i, j))
@@ -1116,6 +1136,12 @@ function _dibimul!(C, A, B, _add)
11161136 # ensure that we fill off-band elements in the destination
11171137 _rmul_or_fill! (C, _add. beta)
11181138 iszero (_add. alpha) && return C
1139+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1140+ _dibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
1141+ C
1142+ end
1143+ function _dibimul_nonzeroalpha! (C, A, B, _add)
1144+ n = size (A,1 )
11191145 if n <= 3
11201146 # For simplicity, use a naive multiplication for small matrices
11211147 # that loops over all elements.
@@ -1152,14 +1178,8 @@ function _dibimul!(C, A, B, _add)
11521178 end # inbounds
11531179 C
11541180end
1155- function _dibimul! (C:: AbstractMatrix , A:: Diagonal , B:: Bidiagonal , _add)
1156- require_one_based_indexing (C)
1157- matmul_size_check (size (C), size (A), size (B))
1181+ function _dibimul_nonzeroalpha! (C:: AbstractMatrix , A:: Diagonal , B:: Bidiagonal , _add)
11581182 n = size (A,1 )
1159- iszero (n) && return C
1160- # ensure that we fill off-band elements in the destination
1161- _rmul_or_fill! (C, _add. beta)
1162- iszero (_add. alpha) && return C
11631183 Ad = A. diag
11641184 Bdv, Bev = B. dv, B. ev
11651185 rowshift = B. uplo == ' U' ? - 1 : 1
@@ -1189,6 +1209,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11891209 n = size (A,1 )
11901210 n == 0 && return C
11911211 iszero (_add. alpha) && return _rmul_or_fill! (C, _add. beta)
1212+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1213+ _dibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
1214+ C
1215+ end
1216+ function _dibimul_nonzeroalpha! (C:: Bidiagonal , A:: Diagonal , B:: Bidiagonal , _add)
11921217 Ad = A. diag
11931218 Bdv, Bev = B. dv, B. ev
11941219 Cdv, Cev = C. dv, C. ev
0 commit comments