Skip to content

Commit e65e75c

Browse files
authored
Clean up herk_wrapper! and add 5-arg tests (#1254)
This method can only be reached with complex eltypes, and after promotion of alpha and beta, they cannot be `Bool` anymore. Also, this adds 5-arg `mul!` of which I'm not sure we had some (even though coverage said it was covered, which is strange because it shouldn't due to the issue fixed in #1247).
1 parent ece1962 commit e65e75c

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

src/matmul.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add
725725

726726
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
727727
# to be concretely inferred
728-
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
729-
α::Number, β::Number) where {T<:BlasReal}
728+
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC},
729+
α::Number, β::Number) where {TC<:BlasComplex}
730+
T = real(TC)
730731
nC = checksquare(C)
731732
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
732733
if tA_uc == 'C'
@@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
740741
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
741742
end
742743

743-
# Result array does not need to be initialized as long as beta==0
744-
# C = Matrix{T}(undef, mA, mA)
745-
744+
# BLAS.herk! only updates hermitian C, alpha and beta need to be real
746745
if iszero(β) || ishermitian(C)
747746
alpha, beta = promote(α, β, zero(T))
748-
if (alpha isa Union{Bool,T} &&
749-
beta isa Union{Bool,T} &&
747+
if (alpha isa T && beta isa T &&
750748
stride(A, 1) == stride(C, 1) == 1 &&
751749
_fullstride2(A) && _fullstride2(C))
752750
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)

test/matmul.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,17 @@ end
539539

540540
A5x5, A6x5 = Matrix{Float64}.(undef, ((5, 5), (6, 5)))
541541
@test_throws DimensionMismatch LinearAlgebra.syrk_wrapper!(A5x5, 'N', A6x5)
542-
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5)
542+
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(complex(A5x5), 'N', complex(A6x5))
543+
end
544+
545+
@testset "5-arg syrk! & herk!" begin
546+
for T in (Float32, Float64, ComplexF32, ComplexF64), A in (randn(T, 5), randn(T, 5, 5))
547+
B = A' * A
548+
C = B isa Number ? [B;;] : Matrix(Hermitian(B))
549+
@test mul!(copy(C), A', A, true, 2) 3C
550+
D = Matrix(Hermitian(A * A'))
551+
@test mul!(copy(D), A, A', true, 3) 4D
552+
end
543553
end
544554

545555
@testset "matmul for types w/o sizeof (issue #1282)" begin

0 commit comments

Comments
 (0)