@@ -488,14 +488,13 @@ end
488
488
489
489
# THE one big BLAS dispatch. This is split into two methods to improve latency
490
490
Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
491
- α:: Number , β:: Number , val:: BlasFlag.SyrkHerkGemm ) where {T<: BlasFloat }
491
+ α:: Number , β:: Number , val:: BlasFlag.SyrkHerkGemm ) where {T<: Number }
492
492
mA, nA = lapack_size (tA, A)
493
493
mB, nB = lapack_size (tB, B)
494
494
if any (iszero, size (A)) || any (iszero, size (B)) || iszero (α)
495
495
matmul_size_check (size (C), (mA, nA), (mB, nB))
496
496
return _rmul_or_fill! (C, β)
497
497
end
498
- matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
499
498
_syrk_herk_gemm_wrapper! (C, tA, tB, A, B, α, β, val)
500
499
return C
501
500
end
@@ -570,6 +569,70 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha,
570
569
_generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
571
570
end
572
571
572
+ """
573
+ generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
574
+
575
+ Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e.,
576
+ ``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise.
577
+ If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and
578
+ ``A' A α + C β`` otherwise. Only the upper triangular is computed.
579
+ """
580
+ function generic_syrk! (C:: StridedMatrix{T} , A:: StridedVecOrMat{T} , conjugate:: Bool , aat:: Bool , α, β) where {T<: Number }
581
+ require_one_based_indexing (C, A)
582
+ nC = checksquare (C)
583
+ m, n = size (A, 1 ), size (A, 2 )
584
+ mA = aat ? m : n
585
+ if nC != mA
586
+ throw (DimensionMismatch (lazy " output matrix has size: $(size(C)), but should have size $((mA, mA))" ))
587
+ end
588
+
589
+ _rmul_or_fill! (C, β)
590
+ @inbounds if ! conjugate
591
+ if aat
592
+ for k ∈ 1 : n, j ∈ 1 : m
593
+ αA_jk = A[j, k] * α
594
+ for i ∈ 1 : j
595
+ C[i, j] += A[i, k] * αA_jk
596
+ end
597
+ end
598
+ else
599
+ for j ∈ 1 : n, i ∈ 1 : j
600
+ temp = A[1 , i] * A[1 , j]
601
+ for k ∈ 2 : m
602
+ temp += A[k, i] * A[k, j]
603
+ end
604
+ C[i, j] += temp * α
605
+ end
606
+ end
607
+ else
608
+ if aat
609
+ for k ∈ 1 : n, j ∈ 1 : m
610
+ αA_jk_bar = conj (A[j, k]) * α
611
+ for i ∈ 1 : j- 1
612
+ C[i, j] += A[i, k] * αA_jk_bar
613
+ end
614
+ C[j, j] += abs2 (A[j, k]) * α
615
+ end
616
+ else
617
+ for j ∈ 1 : n
618
+ for i ∈ 1 : j- 1
619
+ temp = conj (A[1 , i]) * A[1 , j]
620
+ for k ∈ 2 : m
621
+ temp += conj (A[k, i]) * A[k, j]
622
+ end
623
+ C[i, j] += temp * α
624
+ end
625
+ temp = abs2 (A[1 , j])
626
+ for k ∈ 2 : m
627
+ temp += abs2 (A[k, j])
628
+ end
629
+ C[j, j] += temp * α
630
+ end
631
+ end
632
+ end
633
+ return C
634
+ end
635
+
573
636
# legacy method
574
637
Base. @constprop :aggressive generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
575
638
_add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
@@ -713,12 +776,27 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
713
776
if (alpha isa Union{Bool,T} &&
714
777
beta isa Union{Bool,T} &&
715
778
stride (A, 1 ) == stride (C, 1 ) == 1 &&
716
- _fullstride2 (A) && _fullstride2 (C))
717
- return copytri! (BLAS. syrk! (' U' , tA, alpha, A, beta, C), ' U' )
779
+ _fullstride2 (A) && _fullstride2 (C)) &&
780
+ max (nA, mA) ≥ 4
781
+ BLAS. syrk! (' U' , tA, alpha, A, beta, C)
782
+ else
783
+ generic_syrk! (C, A, false , tA_uc == ' N' , alpha, beta)
718
784
end
785
+ return copytri! (C, ' U' )
719
786
end
720
787
return gemm_wrapper! (C, tA, tAt, A, A, α, β)
721
788
end
789
+ Base. @constprop :aggressive function syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
790
+ α:: Number , β:: Number ) where {T<: Number }
791
+
792
+ tA_uc = uppercase (tA) # potentially strip a WrapperChar
793
+ aat = (tA_uc == ' N' )
794
+ if T <: Union{Real,Complex} && (iszero (β) || issymmetric (C))
795
+ return copytri! (generic_syrk! (C, A, false , aat, α, β), ' U' )
796
+ end
797
+ tAt = aat ? ' T' : ' N'
798
+ return _generic_matmatmul! (C, wrap (A, tA), wrap (A, tAt), α, β)
799
+ end
722
800
# legacy method
723
801
syrk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} , _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
724
802
syrk_wrapper! (C, tA, A, _add. alpha, _add. beta)
@@ -746,12 +824,27 @@ Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::Abs
746
824
alpha, beta = promote (α, β, zero (T))
747
825
if (alpha isa T && beta isa T &&
748
826
stride (A, 1 ) == stride (C, 1 ) == 1 &&
749
- _fullstride2 (A) && _fullstride2 (C))
750
- return copytri! (BLAS. herk! (' U' , tA, alpha, A, beta, C), ' U' , true )
827
+ _fullstride2 (A) && _fullstride2 (C)) &&
828
+ max (nA, mA) ≥ 4
829
+ BLAS. herk! (' U' , tA, alpha, A, beta, C)
830
+ else
831
+ generic_syrk! (C, A, true , tA_uc == ' N' , alpha, beta)
751
832
end
833
+ return copytri! (C, ' U' , true )
752
834
end
753
835
return gemm_wrapper! (C, tA, tAt, A, A, α, β)
754
836
end
837
+ Base. @constprop :aggressive function herk_wrapper! (C:: StridedMatrix{T} , tA:: AbstractChar , A:: StridedVecOrMat{T} ,
838
+ α:: Number , β:: Number ) where {T<: Number }
839
+
840
+ tA_uc = uppercase (tA) # potentially strip a WrapperChar
841
+ aat = (tA_uc == ' N' )
842
+ if isreal (α) && isreal (β) && (iszero (β) || ishermitian (C))
843
+ return copytri! (generic_syrk! (C, A, true , aat, α, β), ' U' , true )
844
+ end
845
+ tAt = aat ? ' C' : ' N'
846
+ return _generic_matmatmul! (C, wrap (A, tA), wrap (A, tAt), α, β)
847
+ end
755
848
# legacy method
756
849
herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
757
850
_add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
@@ -785,6 +878,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
785
878
mB, nB = lapack_size (tB, B)
786
879
787
880
matmul_size_check (size (C), (mA, nA), (mB, nB))
881
+ matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
788
882
789
883
if C === A || B === C
790
884
throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
803
897
gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
804
898
A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} , _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
805
899
gemm_wrapper! (C, tA, tB, A, B, _add. alpha, _add. beta)
900
+ # fallback for generic types
901
+ Base. @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
902
+ A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
903
+ α:: Number , β:: Number ) where {T<: Number }
904
+ matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
905
+ return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), α, β)
906
+ end
806
907
807
908
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
808
909
# makes the calls concretely inferred
0 commit comments