@@ -613,6 +613,27 @@ function _diag(A::Bidiagonal, k)
613613 end
614614end
615615
616+ """
617+ _MulAddMul_nonzeroalpha(_add::MulAddMul[, ::Val{false}])
618+
619+ Return a new `MulAddMul` with the value of `alpha` potentially set to a literal non-zero
620+ value if permitted by the type (e.g., for `_add.alpha isa Bool`, in which case the `alpha` is
621+ set to `true` in the returned instance).
622+ In other cases, the single-argument call is a no-op and returns `_add` without modifications.
623+
624+ In addition, if `Val(false)` is provided as the second argument,
625+ `beta` is set to `false` in the returned `MulAddMul` instance.
626+ """
627+ _MulAddMul_nonzeroalpha (_add:: MulAddMul ) = _add
628+ function _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,A} , :: Val{false} ) where {ais1,bis0,A}
629+ MulAddMul {ais1,true,A,Bool} (_add. alpha, false )
630+ end
631+ function _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,Bool} ) where {ais1,bis0}
632+ (; beta) = _add
633+ MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
634+ end
635+ _MulAddMul_nonzeroalpha (_add:: MulAddMul{ais1,bis0,Bool} , :: Val{false} ) where {ais1,bis0} = MulAddMul ()
636+
616637_mul! (C:: AbstractMatrix , A:: BiTriSym , B:: TriSym , _add:: MulAddMul ) =
617638 _bibimul! (C, A, B, _add)
618639_mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Bidiagonal , _add:: MulAddMul ) =
@@ -626,36 +647,54 @@ function _bibimul!(C, A, B, _add)
626647 # `_modify!` in the following loop will not update the
627648 # off-diagonal elements for non-zero beta.
628649 _rmul_or_fill! (C, _add. beta)
629- _iszero_alpha (_add) && return C
630- if n <= 3
650+ iszero (_add. alpha) && return C
651+ # beta is unused in _bibimul_nonzeroalpha!, so we set it to false
652+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add, Val (false ))
653+ _bibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
654+ C
655+ end
656+ function _bibimul_nonzeroalpha! (C, A, B, _add)
657+ n = size (A,1 )
658+ if n == 1
631659 # naive multiplication
632- for I in CartesianIndices (C)
633- C[I] += _add (sum (A[I[1 ], k] * B[k, I[2 ]] for k in axes (A,2 )))
634- end
660+ @inbounds C[1 ,1 ] += _add (A[1 ,1 ] * B[1 ,1 ])
635661 return C
636662 end
637663 @inbounds begin
638664 # first column of C
639665 C[1 ,1 ] += _add (A[1 ,1 ]* B[1 ,1 ] + A[1 , 2 ]* B[2 ,1 ])
640666 C[2 ,1 ] += _add (A[2 ,1 ]* B[1 ,1 ] + A[2 ,2 ]* B[2 ,1 ])
641- C[3 ,1 ] += _add (A[3 ,2 ]* B[2 ,1 ])
667+ if n >= 3
668+ C[3 ,1 ] += _add (A[3 ,2 ]* B[2 ,1 ])
669+ end
642670 # second column of C
643671 C[1 ,2 ] += _add (A[1 ,1 ]* B[1 ,2 ] + A[1 ,2 ]* B[2 ,2 ])
644- C[2 ,2 ] += _add (A[2 ,1 ]* B[1 ,2 ] + A[2 ,2 ]* B[2 ,2 ] + A[2 ,3 ]* B[3 ,2 ])
645- C[3 ,2 ] += _add (A[3 ,2 ]* B[2 ,2 ] + A[3 ,3 ]* B[3 ,2 ])
646- C[4 ,2 ] += _add (A[4 ,3 ]* B[3 ,2 ])
672+ C22 = A[2 ,1 ]* B[1 ,2 ] + A[2 ,2 ]* B[2 ,2 ]
673+ if n >= 3
674+ C[2 ,2 ] += _add (C22 + A[2 ,3 ]* B[3 ,2 ])
675+ C[3 ,2 ] += _add (A[3 ,2 ]* B[2 ,2 ] + A[3 ,3 ]* B[3 ,2 ])
676+ if n >= 4
677+ C[4 ,2 ] += _add (A[4 ,3 ]* B[3 ,2 ])
678+ end
679+ else
680+ C[2 ,2 ] += _add (C22)
681+ end
647682 end # inbounds
648683 # middle columns
649684 __bibimul! (C, A, B, _add)
650685 @inbounds begin
651- C[n- 3 ,n- 1 ] += _add (A[n- 3 ,n- 2 ]* B[n- 2 ,n- 1 ])
652- C[n- 2 ,n- 1 ] += _add (A[n- 2 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 2 ,n- 1 ]* B[n- 1 ,n- 1 ])
653- C[n- 1 ,n- 1 ] += _add (A[n- 1 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 1 ,n- 1 ]* B[n- 1 ,n- 1 ] + A[n- 1 ,n]* B[n,n- 1 ])
654- C[n, n- 1 ] += _add (A[n,n- 1 ]* B[n- 1 ,n- 1 ] + A[n,n]* B[n,n- 1 ])
686+ if n >= 4
687+ C[n- 3 ,n- 1 ] += _add (A[n- 3 ,n- 2 ]* B[n- 2 ,n- 1 ])
688+ C[n- 2 ,n- 1 ] += _add (A[n- 2 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 2 ,n- 1 ]* B[n- 1 ,n- 1 ])
689+ C[n- 1 ,n- 1 ] += _add (A[n- 1 ,n- 2 ]* B[n- 2 ,n- 1 ] + A[n- 1 ,n- 1 ]* B[n- 1 ,n- 1 ] + A[n- 1 ,n]* B[n,n- 1 ])
690+ C[n, n- 1 ] += _add (A[n,n- 1 ]* B[n- 1 ,n- 1 ] + A[n,n]* B[n,n- 1 ])
691+ end
655692 # last column of C
656- C[n- 2 , n] += _add (A[n- 2 ,n- 1 ]* B[n- 1 ,n])
657- C[n- 1 , n] += _add (A[n- 1 ,n- 1 ]* B[n- 1 ,n ] + A[n- 1 ,n]* B[n,n ])
658- C[n, n] += _add (A[n,n- 1 ]* B[n- 1 ,n ] + A[n,n]* B[n,n ])
693+ if n >= 3
694+ C[n- 2 , n] += _add (A[n- 2 ,n- 1 ]* B[n- 1 ,n])
695+ C[n- 1 , n] += _add (A[n- 1 ,n- 1 ]* B[n- 1 ,n ] + A[n- 1 ,n]* B[n,n ])
696+ C[n, n] += _add (A[n,n- 1 ]* B[n- 1 ,n ] + A[n,n]* B[n,n ])
697+ end
659698 end # inbounds
660699 C
661700end
@@ -696,9 +735,9 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
696735 Al = _diag (A, - 1 )
697736 Ad = _diag (A, 0 )
698737 Au = _diag (A, 1 )
699- Bd = _diag (B, 0 )
738+ Bd = B . dv
700739 if B. uplo == ' U'
701- Bu = _diag (B, 1 )
740+ Bu = B . ev
702741 @inbounds begin
703742 for j in 3 : n- 2
704743 Aj₋2j₋1 = Au[j- 2 ]
@@ -717,7 +756,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
717756 end
718757 end
719758 else # B.uplo == 'L'
720- Bl = _diag (B, - 1 )
759+ Bl = B . ev
721760 @inbounds begin
722761 for j in 3 : n- 2
723762 Aj₋1j = Au[j- 1 ]
@@ -743,9 +782,9 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
743782 Bl = _diag (B, - 1 )
744783 Bd = _diag (B, 0 )
745784 Bu = _diag (B, 1 )
746- Ad = _diag (A, 0 )
785+ Ad = A . dv
747786 if A. uplo == ' U'
748- Au = _diag (A, 1 )
787+ Au = A . ev
749788 @inbounds begin
750789 for j in 3 : n- 2
751790 Aj₋2j₋1 = Au[j- 2 ]
@@ -765,7 +804,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
765804 end
766805 end
767806 else # A.uplo == 'L'
768- Al = _diag (A, - 1 )
807+ Al = A . ev
769808 @inbounds begin
770809 for j in 3 : n- 2
771810 Aj₋1j₋1 = Ad[j- 1 ]
@@ -789,11 +828,11 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
789828end
790829function __bibimul! (C, A:: Bidiagonal , B:: Bidiagonal , _add)
791830 n = size (A,1 )
792- Ad = _diag (A, 0 )
793- Bd = _diag (B, 0 )
831+ Ad = A . dv
832+ Bd = B . dv
794833 if A. uplo == ' U' && B. uplo == ' U'
795- Au = _diag (A, 1 )
796- Bu = _diag (B, 1 )
834+ Au = A . ev
835+ Bu = B . ev
797836 @inbounds begin
798837 for j in 3 : n- 2
799838 Aj₋2j₋1 = Au[j- 2 ]
@@ -809,8 +848,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
809848 end
810849 end
811850 elseif A. uplo == ' U' && B. uplo == ' L'
812- Au = _diag (A, 1 )
813- Bl = _diag (B, - 1 )
851+ Au = A . ev
852+ Bl = B . ev
814853 @inbounds begin
815854 for j in 3 : n- 2
816855 Aj₋1j = Au[j- 1 ]
@@ -826,8 +865,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
826865 end
827866 end
828867 elseif A. uplo == ' L' && B. uplo == ' U'
829- Al = _diag (A, - 1 )
830- Bu = _diag (B, 1 )
868+ Al = A . ev
869+ Bu = B . ev
831870 @inbounds begin
832871 for j in 3 : n- 2
833872 Aj₋1j₋1 = Ad[j- 1 ]
@@ -843,8 +882,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
843882 end
844883 end
845884 else # A.uplo == 'L' && B.uplo == 'L'
846- Al = _diag (A, - 1 )
847- Bl = _diag (B, - 1 )
885+ Al = A . ev
886+ Bl = B . ev
848887 @inbounds begin
849888 for j in 3 : n- 2
850889 Ajj = Ad[j]
@@ -863,15 +902,20 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
863902 C
864903end
865904
866- _mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , alpha:: Number , beta:: Number ) =
867- @stable_muladdmul _mul! (C, A, B, MulAddMul (alpha, beta))
868905function _mul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , _add:: MulAddMul )
869906 require_one_based_indexing (C)
870907 matmul_size_check (size (C), size (A), size (B))
871908 n = size (A,1 )
872909 iszero (n) && return C
873910 _rmul_or_fill! (C, _add. beta) # see the same use above
874- _iszero_alpha (_add) && return C
911+ iszero (_add. alpha) && return C
912+ # beta is unused in the _bidimul! call, so we set it to false
913+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add, Val (false ))
914+ _bidimul! (C, A, B, _add_nonzeroalpha)
915+ C
916+ end
917+ function _bidimul! (C:: AbstractMatrix , A:: BiTriSym , B:: Diagonal , _add:: MulAddMul )
918+ n = size (A,1 )
875919 Al = _diag (A, - 1 )
876920 Ad = _diag (A, 0 )
877921 Au = _diag (A, 1 )
@@ -907,14 +951,8 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
907951 end # inbounds
908952 C
909953end
910-
911- function _mul! (C:: AbstractMatrix , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
912- require_one_based_indexing (C)
913- matmul_size_check (size (C), size (A), size (B))
954+ function _bidimul! (C:: AbstractMatrix , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
914955 n = size (A,1 )
915- iszero (n) && return C
916- _rmul_or_fill! (C, _add. beta) # see the same use above
917- _iszero_alpha (_add) && return C
918956 (; dv, ev) = A
919957 Bd = B. diag
920958 rowshift = A. uplo == ' U' ? - 1 : 1
@@ -943,7 +981,13 @@ function _mul!(C::Bidiagonal, A::Bidiagonal, B::Diagonal, _add::MulAddMul)
943981 matmul_size_check (size (C), size (A), size (B))
944982 n = size (A,1 )
945983 iszero (n) && return C
946- _iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
984+ iszero (_add. alpha) && return _rmul_or_fill! (C, _add. beta)
985+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
986+ _bidimul! (C, A, B, _add_nonzeroalpha)
987+ C
988+ end
989+ function _bidimul! (C:: Bidiagonal , A:: Bidiagonal , B:: Diagonal , _add:: MulAddMul )
990+ n = size (A,1 )
947991 Adv, Aev = A. dv, A. ev
948992 Cdv, Cev = C. dv, C. ev
949993 Bd = B. diag
@@ -978,14 +1022,22 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
9781022 nB = size (B,2 )
9791023 (iszero (nA) || iszero (nB)) && return C
9801024 _iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
1025+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1026+ _mul_bitrisym_left! (C, A, B, _add_nonzeroalpha)
1027+ return C
1028+ end
1029+ function _mul_bitrisym_left! (C:: AbstractVecOrMat , A:: BiTriSym , B:: AbstractVecOrMat , _add:: MulAddMul )
1030+ nA = size (A,1 )
1031+ nB = size (B,2 )
9811032 if nA == 1
9821033 A11 = @inbounds A[1 ,1 ]
9831034 for i in axes (B, 2 )
9841035 @inbounds _modify! (_add, A11 * B[1 ,i], C, (1 ,i))
9851036 end
986- return C
1037+ else
1038+ _mul_bitrisym! (C, A, B, _add)
9871039 end
988- _mul_bitrisym! (C, A, B, _add)
1040+ return C
9891041end
9901042function _mul_bitrisym! (C:: AbstractVecOrMat , A:: Bidiagonal , B:: AbstractVecOrMat , _add:: MulAddMul )
9911043 nA = size (A,1 )
@@ -1046,6 +1098,13 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
10461098 n = size (A,1 )
10471099 m = size (B,2 )
10481100 (_iszero_alpha (_add) || iszero (m)) && return _rmul_or_fill! (C, _add. beta)
1101+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1102+ _mul_bitrisym_right! (C, A, B, _add_nonzeroalpha)
1103+ C
1104+ end
1105+ function _mul_bitrisym_right! (C:: AbstractMatrix , A:: AbstractMatrix , B:: TriSym , _add:: MulAddMul )
1106+ n = size (A,1 )
1107+ m = size (B,2 )
10491108 if m == 1
10501109 B11 = B[1 ,1 ]
10511110 return mul! (C, A, B11, _add. alpha, _add. beta)
@@ -1082,6 +1141,12 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
10821141 m, n = size (A)
10831142 (iszero (m) || iszero (n)) && return C
10841143 _iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
1144+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1145+ _mul_bitrisym_right! (C, A, B, _add_nonzeroalpha)
1146+ C
1147+ end
1148+ function _mul_bitrisym_right! (C:: AbstractMatrix , A:: AbstractMatrix , B:: Bidiagonal , _add:: MulAddMul )
1149+ m, n = size (A)
10851150 @inbounds if B. uplo == ' U'
10861151 for j in n: - 1 : 2 , i in 1 : m
10871152 _modify! (_add, A[i,j] * B. dv[j] + A[i,j- 1 ] * B. ev[j- 1 ], C, (i, j))
@@ -1114,6 +1179,13 @@ function _dibimul!(C, A, B, _add)
11141179 # ensure that we fill off-band elements in the destination
11151180 _rmul_or_fill! (C, _add. beta)
11161181 _iszero_alpha (_add) && return C
1182+ # beta is unused in the _dibimul_nonzeroalpha! call, so we set it to false
1183+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add, Val (false ))
1184+ _dibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
1185+ C
1186+ end
1187+ function _dibimul_nonzeroalpha! (C, A, B, _add)
1188+ n = size (A,1 )
11171189 if n <= 3
11181190 # For simplicity, use a naive multiplication for small matrices
11191191 # that loops over all elements.
@@ -1150,14 +1222,8 @@ function _dibimul!(C, A, B, _add)
11501222 end # inbounds
11511223 C
11521224end
1153- function _dibimul! (C:: AbstractMatrix , A:: Diagonal , B:: Bidiagonal , _add)
1154- require_one_based_indexing (C)
1155- matmul_size_check (size (C), size (A), size (B))
1225+ function _dibimul_nonzeroalpha! (C:: AbstractMatrix , A:: Diagonal , B:: Bidiagonal , _add)
11561226 n = size (A,1 )
1157- iszero (n) && return C
1158- # ensure that we fill off-band elements in the destination
1159- _rmul_or_fill! (C, _add. beta)
1160- _iszero_alpha (_add) && return C
11611227 Ad = A. diag
11621228 Bdv, Bev = B. dv, B. ev
11631229 rowshift = B. uplo == ' U' ? - 1 : 1
@@ -1187,6 +1253,11 @@ function _dibimul!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
11871253 n = size (A,1 )
11881254 n == 0 && return C
11891255 _iszero_alpha (_add) && return _rmul_or_fill! (C, _add. beta)
1256+ _add_nonzeroalpha = _MulAddMul_nonzeroalpha (_add)
1257+ _dibimul_nonzeroalpha! (C, A, B, _add_nonzeroalpha)
1258+ C
1259+ end
1260+ function _dibimul_nonzeroalpha! (C:: Bidiagonal , A:: Diagonal , B:: Bidiagonal , _add)
11901261 Ad = A. diag
11911262 Bdv, Bev = B. dv, B. ev
11921263 Cdv, Cev = C. dv, C. ev
0 commit comments