Skip to content

Commit b8a9330

Browse files
committed
Move @stable_muladdmul within _bibimul!
1 parent 3a9e3db commit b8a9330

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

src/bidiag.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -619,26 +619,29 @@ function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bi
619619
end
620620
_MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}, ::Val{false}) where {ais1,bis0} = MulAddMul()
621621

622-
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
623-
_bibimul!(C, A, B, _add)
624-
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
625-
_bibimul!(C, A, B, _add)
626-
function _bibimul!(C, A, B, _add)
622+
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, alpha::Number, beta::Number) =
623+
_bibimul!(C, A, B, alpha, beta)
624+
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, alpha::Number, beta::Number) =
625+
_bibimul!(C, A, B, alpha, beta)
626+
function _bibimul!(C, A, B, alpha, beta)
627627
require_one_based_indexing(C)
628628
matmul_size_check(size(C), size(A), size(B))
629629
n = size(A,1)
630630
iszero(n) && return C
631631
# We use `_rmul_or_fill!` instead of `_modify!` here since using
632632
# `_modify!` in the following loop will not update the
633633
# off-diagonal elements for non-zero beta.
634-
_rmul_or_fill!(C, _add.beta)
635-
iszero(_add.alpha) && return C
634+
_rmul_or_fill!(C, beta)
635+
iszero(alpha) && return C
636636
# beta is unused in _bibimul_nonzeroalpha!, so we set it to false
637-
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
638-
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
637+
@stable_muladdmul _bibimul_nonzeroalpha!(C, A, B, MulAddMul(alpha, false))
639638
C
640639
end
641640
function _bibimul_nonzeroalpha!(C, A, B, _add)
641+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
642+
__bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
643+
end
644+
function __bibimul_nonzeroalpha!(C, A, B, _add)
642645
n = size(A,1)
643646
if n == 1
644647
# naive multiplication
@@ -666,7 +669,7 @@ function _bibimul_nonzeroalpha!(C, A, B, _add)
666669
end
667670
end # inbounds
668671
# middle columns
669-
__bibimul!(C, A, B, _add)
672+
__bibimul_bulk!(C, A, B, _add)
670673
@inbounds begin
671674
if n >= 4
672675
C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1])
@@ -683,7 +686,7 @@ function _bibimul_nonzeroalpha!(C, A, B, _add)
683686
end # inbounds
684687
C
685688
end
686-
function __bibimul!(C, A, B, _add)
689+
function __bibimul_bulk!(C, A, B, _add)
687690
n = size(A,1)
688691
Al = _diag(A, -1)
689692
Ad = _diag(A, 0)
@@ -715,7 +718,7 @@ function __bibimul!(C, A, B, _add)
715718
end
716719
C
717720
end
718-
function __bibimul!(C, A, B::Bidiagonal, _add)
721+
function __bibimul_bulk!(C, A, B::Bidiagonal, _add)
719722
n = size(A,1)
720723
Al = _diag(A, -1)
721724
Ad = _diag(A, 0)
@@ -762,7 +765,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
762765
end
763766
C
764767
end
765-
function __bibimul!(C, A::Bidiagonal, B, _add)
768+
function __bibimul_bulk!(C, A::Bidiagonal, B, _add)
766769
n = size(A,1)
767770
Bl = _diag(B, -1)
768771
Bd = _diag(B, 0)
@@ -811,7 +814,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
811814
end
812815
C
813816
end
814-
function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
817+
function __bibimul_bulk!(C, A::Bidiagonal, B::Bidiagonal, _add)
815818
n = size(A,1)
816819
Ad = A.dv
817820
Bd = B.dv

0 commit comments

Comments
 (0)