Skip to content

Commit 6ba90cc

Browse files
committed
still better than the fallback
1 parent 7aa9ed5 commit 6ba90cc

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/matmul.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ end
772772
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
773773
# to be concretely inferred
774774
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
775-
alpha::Number, beta::Number) where {T<:BlasFloat}
775+
α::Number, β::Number) where {T<:BlasFloat}
776776
nC = checksquare(C)
777777
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
778778
if tA_uc == 'T'
@@ -788,16 +788,18 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
788788

789789
# BLAS.syrk! only updates symmetric C
790790
# alternatively, make non-zero β a show-stopper for BLAS.syrk!
791-
if iszero(beta) || issymmetric(C)
792-
α, β = promote(alpha, beta, zero(T))
791+
if iszero(β) || issymmetric(C)
792+
alpha, beta = promote(α, β, zero(T))
793793
if (alpha isa Union{Bool,T} &&
794794
beta isa Union{Bool,T} &&
795795
stride(A, 1) == stride(C, 1) == 1 &&
796796
_fullstride2(A) && _fullstride2(C))
797797
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
798+
else
799+
return copytri!(generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta), 'U')
798800
end
799801
end
800-
return gemm_wrapper!(C, tA, tAt, A, A, alpha, beta)
802+
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
801803
end
802804
# legacy method
803805
syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
@@ -830,6 +832,8 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
830832
stride(A, 1) == stride(C, 1) == 1 &&
831833
_fullstride2(A) && _fullstride2(C))
832834
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
835+
else
836+
return copytri!(generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta), 'U', true)
833837
end
834838
end
835839
return gemm_wrapper!(C, tA, tAt, A, A, α, β)

0 commit comments

Comments
 (0)