Skip to content

Commit a32a281

Browse files
araujoms, dkarrasch, jishnub
authored
add generic syrk/herk (#1249)
I've added an implementation of syrk/herk for generic types, in order to avoid falling back to `_generic_matmatmul!`, as it's rather slow. I didn't do anything fancy, no multithreading or anything, but this gives a 1.5x to 2x speedup for `Int` and `BigFloat`, for example. The generic version of syrk only works for real and complex numbers, but funnily enough herk works for anything that respects `conj(a*b) == conj(b)*conj(a)`, which as far as I can tell is any subtype of `Number`, including quaternions and octonions. I've run the tests locally by reverting to the commit before the lazy JLLs one, and they pass. --------- Co-authored-by: Daniel Karrasch <[email protected]> Co-authored-by: Jishnu Bhattacharya <[email protected]>
1 parent cab0dc6 commit a32a281

File tree

3 files changed

+145
-6
lines changed

3 files changed

+145
-6
lines changed

src/generic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ julia> rmul!([NaN], 0.0)
296296
```
297297
"""
298298
function rmul!(X::AbstractArray, s::Number)
299+
isone(s) && return X
299300
@simd for I in eachindex(X)
300301
@inbounds X[I] *= s
301302
end
@@ -334,6 +335,7 @@ julia> lmul!(0.0, [Inf])
334335
```
335336
"""
336337
function lmul!(s::Number, X::AbstractArray)
338+
isone(s) && return X
337339
@simd for I in eachindex(X)
338340
@inbounds X[I] = s*X[I]
339341
end

src/matmul.jl

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,14 +488,13 @@ end
488488

489489
# THE one big BLAS dispatch. This is split into two methods to improve latency
490490
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
491-
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat}
491+
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
492492
mA, nA = lapack_size(tA, A)
493493
mB, nB = lapack_size(tB, B)
494494
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
495495
matmul_size_check(size(C), (mA, nA), (mB, nB))
496496
return _rmul_or_fill!(C, β)
497497
end
498-
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
499498
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
500499
return C
501500
end
@@ -570,6 +569,70 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha,
570569
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
571570
end
572571

"""
    generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}

Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e.,
``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise.
If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and
``A' A α + C β`` otherwise. Only the upper triangular is computed.

`C` is updated in place and returned. It is first passed through `_rmul_or_fill!(C, β)`;
afterwards only the entries on and above the diagonal receive the product term, so a
caller that needs the full matrix has to mirror the upper triangle itself. If `A` is
empty the product term is zero and `C` is returned right after the `β` step.

Throws a `DimensionMismatch` if the size of `C` does not match the size of the product.
"""
function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
    require_one_based_indexing(C, A)
    nC = checksquare(C)
    m, n = size(A, 1), size(A, 2)
    mA = aat ? m : n
    if nC != mA
        throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
    end

    _rmul_or_fill!(C, β)
    # An empty `A` contributes nothing. Returning here also keeps the `A[1, i]`
    # accumulator seeds below from reading out of bounds when `A` has no rows.
    isempty(A) && return C
    @inbounds if !conjugate
        if aat
            # C[i, j] += Σ_k A[i, k] * A[j, k] * α, accumulated one column of A at a time
            for k in 1:n, j in 1:m
                αA_jk = A[j, k] * α
                for i in 1:j
                    C[i, j] += A[i, k] * αA_jk
                end
            end
        else
            # C[i, j] += (Σ_k A[k, i] * A[k, j]) * α; the sum is seeded with the k = 1
            # term so that no zero of the product type has to be constructed
            for j in 1:n, i in 1:j
                temp = A[1, i] * A[1, j]
                for k in 2:m
                    temp += A[k, i] * A[k, j]
                end
                C[i, j] += temp * α
            end
        end
    else
        if aat
            for k in 1:n, j in 1:m
                αA_jk_bar = conj(A[j, k]) * α
                # strictly upper part: C[i, j] += A[i, k] * conj(A[j, k]) * α
                for i in 1:j-1
                    C[i, j] += A[i, k] * αA_jk_bar
                end
                # the diagonal term is accumulated through `abs2`
                C[j, j] += abs2(A[j, k]) * α
            end
        else
            for j in 1:n
                # strictly upper part: C[i, j] += (Σ_k conj(A[k, i]) * A[k, j]) * α
                for i in 1:j-1
                    temp = conj(A[1, i]) * A[1, j]
                    for k in 2:m
                        temp += conj(A[k, i]) * A[k, j]
                    end
                    C[i, j] += temp * α
                end
                # the diagonal term is accumulated through `abs2`
                temp = abs2(A[1, j])
                for k in 2:m
                    temp += abs2(A[k, j])
                end
                C[j, j] += temp * α
            end
        end
    end
    return C
end
635+
573636
# legacy method
574637
Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
575638
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
@@ -713,12 +776,27 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
713776
if (alpha isa Union{Bool,T} &&
714777
beta isa Union{Bool,T} &&
715778
stride(A, 1) == stride(C, 1) == 1 &&
716-
_fullstride2(A) && _fullstride2(C))
717-
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
779+
_fullstride2(A) && _fullstride2(C)) &&
780+
max(nA, mA) ≥ 4
781+
BLAS.syrk!('U', tA, alpha, A, beta, C)
782+
else
783+
generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta)
718784
end
785+
return copytri!(C, 'U')
719786
end
720787
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
721788
end
789+
# Generic-eltype syrk dispatch. `generic_syrk!` fills only the upper triangle, which
# `copytri!` then mirrors onto the lower one, so that path is taken only when `T` is real
# or complex and the `β`-scaled `C` is zero or already symmetric. Everything else falls
# back to `_generic_matmatmul!`.
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    # `uppercase` potentially strips a WrapperChar
    notrans = uppercase(tA) == 'N'
    use_syrk = T <: Union{Real,Complex} && (iszero(β) || issymmetric(C))
    if !use_syrk
        return _generic_matmatmul!(C, wrap(A, tA), wrap(A, notrans ? 'T' : 'N'), α, β)
    end
    generic_syrk!(C, A, false, notrans, α, β)
    return copytri!(C, 'U')
end
722800
# legacy method
723801
syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
724802
syrk_wrapper!(C, tA, A, _add.alpha, _add.beta)
@@ -746,12 +824,27 @@ Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::Abs
746824
alpha, beta = promote(α, β, zero(T))
747825
if (alpha isa T && beta isa T &&
748826
stride(A, 1) == stride(C, 1) == 1 &&
749-
_fullstride2(A) && _fullstride2(C))
750-
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
827+
_fullstride2(A) && _fullstride2(C)) &&
828+
max(nA, mA) ≥ 4
829+
BLAS.herk!('U', tA, alpha, A, beta, C)
830+
else
831+
generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta)
751832
end
833+
return copytri!(C, 'U', true)
752834
end
753835
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
754836
end
837+
# Generic-eltype herk dispatch. `generic_syrk!` (in conjugating mode) fills only the
# upper triangle, which `copytri!` then mirrors with conjugation onto the lower one, so
# that path is taken only when `α` and `β` are real and the `β`-scaled `C` is zero or
# already Hermitian. Everything else falls back to `_generic_matmatmul!`.
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    # `uppercase` potentially strips a WrapperChar
    notrans = uppercase(tA) == 'N'
    use_herk = isreal(α) && isreal(β) && (iszero(β) || ishermitian(C))
    if !use_herk
        return _generic_matmatmul!(C, wrap(A, tA), wrap(A, notrans ? 'C' : 'N'), α, β)
    end
    generic_syrk!(C, A, true, notrans, α, β)
    return copytri!(C, 'U', true)
end
755848
# legacy method
756849
herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
757850
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
@@ -785,6 +878,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
785878
mB, nB = lapack_size(tB, B)
786879

787880
matmul_size_check(size(C), (mA, nA), (mB, nB))
881+
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
788882

789883
if C === A || B === C
790884
throw(ArgumentError("output matrix must not be aliased with input matrix"))
@@ -803,6 +897,13 @@ end
803897
gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
804898
A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
805899
gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta)
900+
# Fallback for generic (non-BLAS) element types: try the dedicated 2x2/3x3 kernel first
# and otherwise hand the wrapped operands to `_generic_matmatmul!`.
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
                                       A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                                       α::Number, β::Number) where {T<:Number}
    if matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β)
        return C
    end
    return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
end
806907

807908
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
808909
# makes the calls concretely inferred

test/generic.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,42 @@ end
123123
@test_throws DimensionMismatch axpy!(α, x, Vector(1:3), y, Vector(1:5))
124124
end
125125

126+
# Checks that `mul!` on `a * a'`, `a' * a`, `a * transpose(a)` and `transpose(a) * a`
# agrees with the `_generic_matmatmul!` fallback for non-BLAS element types.
@testset "generic syrk & herk" begin
    for T in (BigFloat, Complex{BigFloat}, Quaternion{Float64})
        α = randn(T)
        a = randn(T, 3, 4)
        csmall = similar(a, 3, 3)
        csmall_fallback = similar(a, 3, 3)
        cbig = similar(a, 4, 4)
        cbig_fallback = similar(a, 4, 4)
        # herk with a real scaling factor
        mul!(csmall, a, a', real(α), false)
        LinearAlgebra._generic_matmatmul!(csmall_fallback, a, a', real(α), false)
        @test ishermitian(csmall)
        @test csmall ≈ csmall_fallback
        mul!(cbig, a', a, real(α), false)
        LinearAlgebra._generic_matmatmul!(cbig_fallback, a', a, real(α), false)
        @test ishermitian(cbig)
        @test cbig ≈ cbig_fallback
        # syrk with an arbitrary scaling factor
        mul!(csmall, a, transpose(a), α, false)
        LinearAlgebra._generic_matmatmul!(csmall_fallback, a, transpose(a), α, false)
        @test csmall ≈ csmall_fallback
        mul!(cbig, transpose(a), a, α, false)
        LinearAlgebra._generic_matmatmul!(cbig_fallback, transpose(a), a, α, false)
        @test cbig ≈ cbig_fallback
        if T <: Union{Real, Complex}
            @test issymmetric(csmall)
            @test issymmetric(cbig)
        end
        #make sure generic herk is not called for non-real α
        mul!(csmall, a, a', α, false)
        LinearAlgebra._generic_matmatmul!(csmall_fallback, a, a', α, false)
        @test csmall ≈ csmall_fallback
        mul!(cbig, a', a, α, false)
        LinearAlgebra._generic_matmatmul!(cbig_fallback, a', a, α, false)
        @test cbig ≈ cbig_fallback
    end
end
161+
126162
@test !issymmetric(fill(1,5,3))
127163
@test !ishermitian(fill(1,5,3))
128164
@test (x = fill(1,3); cross(x,x) == zeros(3))

0 commit comments

Comments
 (0)