Skip to content

Commit c1630b6

Browse files
committed
hang generic code lower in the dispatch hierarchy
1 parent 67a3bd7 commit c1630b6

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

src/matmul.jl

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -488,14 +488,13 @@ end
488488

489489
# THE one big BLAS dispatch. This is split into two methods to improve latency
490490
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
491-
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat}
491+
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
492492
mA, nA = lapack_size(tA, A)
493493
mB, nB = lapack_size(tB, B)
494494
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
495495
matmul_size_check(size(C), (mA, nA), (mB, nB))
496496
return _rmul_or_fill!(C, β)
497497
end
498-
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
499498
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
500499
return C
501500
end
@@ -570,29 +569,6 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha,
570569
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
571570
end
572571

573-
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
574-
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
575-
mA, nA = lapack_size(tA, A)
576-
mB, nB = lapack_size(tB, B)
577-
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
578-
matmul_size_check(size(C), (mA, nA), (mB, nB))
579-
return _rmul_or_fill!(C, β)
580-
end
581-
582-
if A === B
583-
tA_uc = uppercase(tA) # potentially strip a WrapperChar
584-
aat = (tA_uc == 'N')
585-
blasfn = _valtypeparam(val)
586-
if blasfn == BlasFlag.SYRK && T <: Union{Real,Complex} && (iszero(β) || issymmetric(C))
587-
return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U')
588-
elseif blasfn == BlasFlag.HERK && isreal(α) && isreal(β) && (iszero(β) || ishermitian(C))
589-
return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true)
590-
end
591-
end
592-
593-
return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
594-
end
595-
596572
"""
597573
generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
598574
@@ -800,14 +776,25 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
800776
if (alpha isa Union{Bool,T} &&
801777
beta isa Union{Bool,T} &&
802778
stride(A, 1) == stride(C, 1) == 1 &&
803-
_fullstride2(A) && _fullstride2(C))
779+
_fullstride2(A) && _fullstride2(C)) &&
780+
max(nA, mA) ≥ 4
804781
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
805782
else
806783
return copytri!(generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta), 'U')
807784
end
808785
end
809786
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
810787
end
788+
# Generic (non-BLAS element type) method of `syrk_wrapper!`: multiplies `A` (as
# wrapped by `tA`) with its transposed counterpart and stores the result in `C`,
# scaled by `α` and combined with `β`-scaled `C` by the callees.
#
# Arguments:
#   C    - destination matrix, mutated in place and returned by the callees
#   tA   - transposition flag for `A`; may be a wrapper char, hence the `uppercase`
#   A    - the single input factor (the product uses `A` on both sides)
#   α, β - scalar multipliers forwarded to the kernels
#
# When `T` is real or complex and either `β` is zero or `C` is already symmetric,
# the product is computed by `generic_syrk!` and passed through `copytri!(…, 'U')`.
# Otherwise the general `_generic_matmatmul!` kernel is used.
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    tA_uc = uppercase(tA) # potentially strip a WrapperChar
    aat = (tA_uc == 'N')
    if T <: Union{Real,Complex} && (iszero(β) || issymmetric(C))
        return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U')
    end
    # Only `A` and `tA` exist in this method (there is no `B`/`tB`), so the second
    # factor is `A` again with the opposite transposition flag. This mirrors the
    # `gemm_wrapper!(C, tA, tAt, A, A, α, β)` fallback of the BLAS method.
    # NOTE(review): assumes `tA` is 'N' or 'T' after stripping — confirm with callers.
    tAt = aat ? 'T' : 'N'
    return _generic_matmatmul!(C, wrap(A, tA), wrap(A, tAt), α, β)
end
811798
# legacy method
812799
syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
813800
syrk_wrapper!(C, tA, A, _add.alpha, _add.beta)
@@ -837,14 +824,26 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
837824
if (alpha isa Union{Bool,T} &&
838825
beta isa Union{Bool,T} &&
839826
stride(A, 1) == stride(C, 1) == 1 &&
840-
_fullstride2(A) && _fullstride2(C))
827+
_fullstride2(A) && _fullstride2(C)) &&
828+
max(nA, mA) ≥ 4
841829
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
842830
else
843831
return copytri!(generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta), 'U', true)
844832
end
845833
end
846834
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
847835
end
836+
# Generic (non-BLAS element type) method of `herk_wrapper!`: multiplies `A` (as
# wrapped by `tA`) with its adjoint counterpart and stores the result in `C`,
# scaled by `α` and combined with `β`-scaled `C` by the callees.
#
# Arguments:
#   C    - destination matrix, mutated in place and returned by the callees
#   tA   - transposition flag for `A`; may be a wrapper char, hence the `uppercase`
#   A    - the single input factor (the product uses `A` on both sides)
#   α, β - scalar multipliers forwarded to the kernels
#
# When both `α` and `β` are real and either `β` is zero or `C` is already
# Hermitian, the product is computed by `generic_syrk!` in conjugating mode and
# passed through `copytri!(…, 'U', true)`. Otherwise the general
# `_generic_matmatmul!` kernel is used.
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    tA_uc = uppercase(tA) # potentially strip a WrapperChar
    aat = (tA_uc == 'N')

    if isreal(α) && isreal(β) && (iszero(β) || ishermitian(C))
        return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true)
    end
    # Only `A` and `tA` exist in this method (there is no `B`/`tB`), so the second
    # factor is `A` again with the opposite adjoint flag. This mirrors the
    # `gemm_wrapper!(C, tA, tAt, A, A, α, β)` fallback of the BLAS method.
    # NOTE(review): assumes `tA` is 'N' or 'C' after stripping — confirm with callers.
    tAt = aat ? 'C' : 'N'
    return _generic_matmatmul!(C, wrap(A, tA), wrap(A, tAt), α, β)
end
848847
# legacy method
849848
herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
850849
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
@@ -896,6 +895,13 @@ end
896895
gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
897896
A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
898897
gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta)
898+
# fallback for generic types
899+
# `gemm_wrapper!` method for any `Number` element type.
# First offers the product to `matmul2x2or3x3_nonzeroalpha!`; if that reports
# success, `C` is returned as is. Otherwise both operands are wrapped according
# to their transposition flags and handed to `_generic_matmatmul!`.
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
        A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    if matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β)
        return C
    end
    lhs = wrap(A, tA)
    rhs = wrap(B, tB)
    return _generic_matmatmul!(C, lhs, rhs, α, β)
end
899905

900906
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
901907
# makes the calls concretely inferred

0 commit comments

Comments
 (0)