@@ -488,14 +488,13 @@ end
488488
489489# THE one big BLAS dispatch. This is split into two methods to improve latency
490490Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
491- α:: Number , β:: Number , val:: BlasFlag.SyrkHerkGemm ) where {T<: BlasFloat }
491+ α:: Number , β:: Number , val:: BlasFlag.SyrkHerkGemm ) where {T<: Number }
492492 mA, nA = lapack_size (tA, A)
493493 mB, nB = lapack_size (tB, B)
494494 if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
495495 matmul_size_check (size (C), (mA, nA), (mB, nB))
496496 return _rmul_or_fill! (C, β)
497497 end
498- matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
499498 _syrk_herk_gemm_wrapper! (C, tA, tB, A, B, α, β, val)
500499 return C
501500end
@@ -570,29 +569,6 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha,
570569 _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
571570end
572571
573- Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
574- α:: Number , β:: Number , val:: BlasFlag.SyrkHerkGemm ) where {T<: Number }
575- mA, nA = lapack_size (tA, A)
576- mB, nB = lapack_size (tB, B)
577- if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
578- matmul_size_check (size (C), (mA, nA), (mB, nB))
579- return _rmul_or_fill! (C, β)
580- end
581-
582- if A === B
583- tA_uc = uppercase (tA) # potentially strip a WrapperChar
584- aat = (tA_uc == ' N' )
585- blasfn = _valtypeparam (val)
586- if blasfn == BlasFlag. SYRK && T <: Union{Real,Complex} && (iszero (β) || issymmetric (C))
587- return copytri! (generic_syrk! (C, A, false , aat, α, β), ' U' )
588- elseif blasfn == BlasFlag. HERK && isreal (α) && isreal (β) && (iszero (β) || ishermitian (C))
589- return copytri! (generic_syrk! (C, A, true , aat, α, β), ' U' , true )
590- end
591- end
592-
593- return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
594- end
595-
596572"""
597573 generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
598574
@@ -800,14 +776,25 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
800776 if (alpha isa Union{Bool,T} &&
801777 beta isa Union{Bool,T} &&
802778 stride (A, 1 ) == stride (C, 1 ) == 1 &&
803- _fullstride2 (A) && _fullstride2 (C))
779+ _fullstride2 (A) && _fullstride2 (C)) &&
780+ max (nA, mA) ≥ 4
804781 return copytri! (BLAS. syrk! (' U' , tA, alpha, A, beta, C), ' U' )
805782 else
806783 return copytri! (generic_syrk! (C, A, false , tA_uc == ' N' , alpha, beta), ' U' )
807784 end
808785 end
809786 return gemm_wrapper! (C, tA, tAt, A, A, α, β)
810787end
788+ Base. @constprop :aggressive function syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
789+ α:: Number , β:: Number ) where {T<: Number }
790+
791+ tA_uc = uppercase (tA) # potentially strip a WrapperChar
792+ aat = (tA_uc == ' N' )
793+ if T <: Union{Real,Complex} && (iszero (β) || issymmetric (C))
794+ return copytri! (generic_syrk! (C, A, false , aat, α, β), ' U' )
795+ end
796+ return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
797+ end
811798# legacy method
812799syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} , _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
813800 syrk_wrapper! (C, tA, A, _add. alpha, _add. beta)
@@ -837,14 +824,26 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
837824 if (alpha isa Union{Bool,T} &&
838825 beta isa Union{Bool,T} &&
839826 stride (A, 1 ) == stride (C, 1 ) == 1 &&
840- _fullstride2 (A) && _fullstride2 (C))
827+ _fullstride2 (A) && _fullstride2 (C)) &&
828+ max (nA, mA) ≥ 4
841829 return copytri! (BLAS. herk! (' U' , tA, alpha, A, beta, C), ' U' , true )
842830 else
843831 return copytri! (generic_syrk! (C, A, true , tA_uc == ' N' , alpha, beta), ' U' , true )
844832 end
845833 end
846834 return gemm_wrapper! (C, tA, tAt, A, A, α, β)
847835end
836+ Base. @constprop :aggressive function herk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
837+ α:: Number , β:: Number ) where {T<: Number }
838+
839+ tA_uc = uppercase (tA) # potentially strip a WrapperChar
840+ aat = (tA_uc == ' N' )
841+
842+ if isreal (α) && isreal (β) && (iszero (β) || ishermitian (C))
843+ return copytri! (generic_syrk! (C, A, true , aat, α, β), ' U' , true )
844+ end
845+ return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
846+ end
848847# legacy method
849848herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
850849 _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
896895gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
897896 A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} , _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
898897 gemm_wrapper! (C, tA, tB, A, B, _add. alpha, _add. beta)
898+ # fallback for generic types
899+ Base. @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
900+ A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
901+ α:: Number , β:: Number ) where {T<: Number }
902+ matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
903+ return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
904+ end
899905
900906# Aggressive constprop helps propagate the values of tA and tB into wrap, which
901907# makes the calls concretely inferred
0 commit comments