Skip to content

Commit d60df5f

Browse files
committed
Move @stable_muladdmul within _bibimul!
1 parent bfc8613 commit d60df5f

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

src/bidiag.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -621,26 +621,29 @@ function _MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}) where {ais1,bi
621621
end
622622
_MulAddMul_nonzeroalpha(_add::MulAddMul{ais1,bis0,Bool}, ::Val{false}) where {ais1,bis0} = MulAddMul()
623623

624-
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
625-
_bibimul!(C, A, B, _add)
626-
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
627-
_bibimul!(C, A, B, _add)
628-
function _bibimul!(C, A, B, _add)
624+
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, alpha::Number, beta::Number) =
625+
_bibimul!(C, A, B, alpha, beta)
626+
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, alpha::Number, beta::Number) =
627+
_bibimul!(C, A, B, alpha, beta)
628+
function _bibimul!(C, A, B, alpha, beta)
629629
require_one_based_indexing(C)
630630
matmul_size_check(size(C), size(A), size(B))
631631
n = size(A,1)
632632
iszero(n) && return C
633633
# We use `_rmul_or_fill!` instead of `_modify!` here since using
634634
# `_modify!` in the following loop will not update the
635635
# off-diagonal elements for non-zero beta.
636-
_rmul_or_fill!(C, _add.beta)
637-
_iszero_alpha(_add) && return C
636+
_rmul_or_fill!(C, beta)
637+
iszero(alpha) && return C
638638
# beta is unused in _bibimul_nonzeroalpha!, so we set it to false
639-
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add, Val(false))
640-
_bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
639+
@stable_muladdmul _bibimul_nonzeroalpha!(C, A, B, MulAddMul(alpha, false))
641640
C
642641
end
643642
function _bibimul_nonzeroalpha!(C, A, B, _add)
643+
_add_nonzeroalpha = _MulAddMul_nonzeroalpha(_add)
644+
__bibimul_nonzeroalpha!(C, A, B, _add_nonzeroalpha)
645+
end
646+
function __bibimul_nonzeroalpha!(C, A, B, _add)
644647
n = size(A,1)
645648
if n == 1
646649
# naive multiplication
@@ -668,7 +671,7 @@ function _bibimul_nonzeroalpha!(C, A, B, _add)
668671
end
669672
end # inbounds
670673
# middle columns
671-
__bibimul!(C, A, B, _add)
674+
__bibimul_bulk!(C, A, B, _add)
672675
@inbounds begin
673676
if n >= 4
674677
C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1])
@@ -685,7 +688,7 @@ function _bibimul_nonzeroalpha!(C, A, B, _add)
685688
end # inbounds
686689
C
687690
end
688-
function __bibimul!(C, A, B, _add)
691+
function __bibimul_bulk!(C, A, B, _add)
689692
n = size(A,1)
690693
Al = _diag(A, -1)
691694
Ad = _diag(A, 0)
@@ -717,7 +720,7 @@ function __bibimul!(C, A, B, _add)
717720
end
718721
C
719722
end
720-
function __bibimul!(C, A, B::Bidiagonal, _add)
723+
function __bibimul_bulk!(C, A, B::Bidiagonal, _add)
721724
n = size(A,1)
722725
Al = _diag(A, -1)
723726
Ad = _diag(A, 0)
@@ -764,7 +767,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add)
764767
end
765768
C
766769
end
767-
function __bibimul!(C, A::Bidiagonal, B, _add)
770+
function __bibimul_bulk!(C, A::Bidiagonal, B, _add)
768771
n = size(A,1)
769772
Bl = _diag(B, -1)
770773
Bd = _diag(B, 0)
@@ -813,7 +816,7 @@ function __bibimul!(C, A::Bidiagonal, B, _add)
813816
end
814817
C
815818
end
816-
function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
819+
function __bibimul_bulk!(C, A::Bidiagonal, B::Bidiagonal, _add)
817820
n = size(A,1)
818821
Ad = A.dv
819822
Bd = B.dv

0 commit comments

Comments
 (0)