Skip to content

Commit 24120f1

Browse files
committed
Clean up herk_wrapper! and add 5-arg tests
1 parent e64a3df commit 24120f1

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/matmul.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add
725725

726726
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
727727
# to be concretely inferred
728-
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
729-
α::Number, β::Number) where {T<:BlasReal}
728+
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC},
729+
α::Number, β::Number) where {TC<:BlasComplex}
730+
T = real(TC)
730731
nC = checksquare(C)
731732
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
732733
if tA_uc == 'C'
@@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
740741
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
741742
end
742743

743-
# Result array does not need to be initialized as long as beta==0
744-
# C = Matrix{T}(undef, mA, mA)
745-
744+
# BLAS.herk! only updates hermitian C, alpha and beta need to be real
746745
if iszero(β) || ishermitian(C)
747746
alpha, beta = promote(α, β, zero(T))
748-
if (alpha isa Union{Bool,T} &&
749-
beta isa Union{Bool,T} &&
747+
if (alpha isa T && beta isa T &&
750748
stride(A, 1) == stride(C, 1) == 1 &&
751749
_fullstride2(A) && _fullstride2(C))
752750
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)

test/matmul.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,16 @@ end
542542
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5)
543543
end
544544

545+
@testset "5-arg syrk! & herk!" begin
546+
for T in (Float32, Float64, ComplexF32, ComplexF64), A in (randn(T, 5), randn(T, 5, 5))
547+
B = A' * A
548+
C = B isa Number ? [B;;] : Matrix(Hermitian(B))
549+
@test mul!(copy(C), A', A, true, 2) ≈ 3C
550+
D = Matrix(Hermitian(A * A'))
551+
@test mul!(copy(D), A, A', true, 3) ≈ 4D
552+
end
553+
end
554+
545555
@testset "matmul for types w/o sizeof (issue #1282)" begin
546556
AA = fill(complex(1, 1), 10, 10)
547557
for A in (copy(AA), view(AA, 1:10, 1:10))

0 commit comments

Comments
 (0)