From 0fa6a16a4fe603ae53ae145cd17ef47ba65fd9ac Mon Sep 17 00:00:00 2001 From: araujoms Date: Wed, 7 May 2025 14:24:16 +0200 Subject: [PATCH 1/5] cleanup hermitian matrix functions --- src/symmetric.jl | 151 ++++++++++++----------------------------------- 1 file changed, 38 insertions(+), 113 deletions(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index ed5fe3b5..af2d8bb6 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -224,11 +224,15 @@ const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}} const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}} const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}} const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}} -const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}} +const SelfAdjoint = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian} wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric wrappertype(::Hermitian) = Hermitian +nonhermitianwrappertype(::SymSymTri{<:Real}) = Symmetric +nonhermitianwrappertype(::Hermitian{<:Real}) = Symmetric +nonhermitianwrappertype(::Hermitian) = identity + size(A::HermOrSym) = size(A.data) axes(A::HermOrSym) = axes(A.data) @inline function Base.isassigned(A::HermOrSym, i::Int, j::Int) @@ -834,119 +838,65 @@ end ^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p) ^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p) ^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p) -function sympow(A::SymSymTri, p::Integer) - if p < 0 - return Symmetric(Base.power_by_squaring(inv(A), -p)) - else - return Symmetric(Base.power_by_squaring(A, p)) - end -end -for hermtype in (:Symmetric, :SymTridiagonal) - @eval begin - function ^(A::$hermtype{<:Real}, p::Real) - isinteger(p) && return integerpow(A, p) - F = eigen(A) - if all(λ -> λ ≥ 0, F.values) - return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors') - else - return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors') - end - end - function ^(A::$hermtype{<:Complex}, p::Real) - isinteger(p) && return integerpow(A, p) - return Symmetric(schurpow(A, p)) - end - end -end -function ^(A::Hermitian, p::Integer) +^(A::Hermitian, p::Integer) = sympow(A, p) +function sympow(A, p::Integer) if p < 0 - retmat = Base.power_by_squaring(inv(A), -p) + return wrappertype(A)(Base.power_by_squaring(inv(A), -p)) else - retmat = Base.power_by_squaring(A, p) + return wrappertype(A)(Base.power_by_squaring(A, p)) end - return Hermitian(retmat) end -function ^(A::Hermitian{T}, p::Real) where T +function ^(A::SelfAdjoint, p::Real) isinteger(p) && return integerpow(A, p) F = eigen(A) if all(λ -> λ ≥ 0, F.values) - retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors' - return Hermitian(retmat) + return wrappertype(A)((F.vectors * Diagonal((F.values).^p)) * F.vectors') else - retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors' - if T <: Real - return Symmetric(retmat) - else - return retmat - end + return nonhermitianwrappertype(A)((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors') end end +function ^(A::SymSymTri{<:Complex}, p::Real) + isinteger(p) && return integerpow(A, p) + return Symmetric(schurpow(A, p)) +end for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt) @eval begin - function ($func)(A::RealHermSymSymTri) + function ($func)(A::SelfAdjoint) F = eigen(A) return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors') end - function ($func)(A::Hermitian{<:Complex}) - F = eigen(A) - retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' - return Hermitian(retmat) - end end end -function cis(A::RealHermSymSymTri) - F = eigen(A) - return Symmetric(F.vectors .* cis.(F.values') * F.vectors') -end -function cis(A::Hermitian{<:Complex}) +function cis(A::SelfAdjoint) F = eigen(A) - return F.vectors .* cis.(F.values') * F.vectors' + return nonhermitianwrappertype(A)(F.vectors .* cis.(F.values') * F.vectors') end - for func in (:acos, :asin) @eval begin - function ($func)(A::RealHermSymSymTri) + function ($func)(A::SelfAdjoint) F = eigen(A) if all(λ -> -1 ≤ λ ≤ 1, F.values) return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors') else - return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') - end - end - function ($func)(A::Hermitian{<:Complex}) - F = eigen(A) - if all(λ -> -1 ≤ λ ≤ 1, F.values) - retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' - return Hermitian(retmat) - else - return (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors' + return nonhermitianwrappertype(A)((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') end end end end -function acosh(A::RealHermSymSymTri) +function acosh(A::SelfAdjoint) F = eigen(A) if all(λ -> λ ≥ 1, F.values) return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors') else - return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors') - end -end -function acosh(A::Hermitian{<:Complex}) - F = eigen(A) - if all(λ -> λ ≥ 1, F.values) - retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors' - return Hermitian(retmat) - else - return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors' + return nonhermitianwrappertype(A)((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors') end end -function sincos(A::RealHermSymSymTri) +function sincos(A::SelfAdjoint) n = checksquare(A) F = eigen(A) T = float(eltype(F.values)) @@ -956,49 +906,24 @@ function sincos(A::RealHermSymSymTri) end return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors') end -function sincos(A::Hermitian{<:Complex}) - n = checksquare(A) + +function log(A::SelfAdjoint) F = eigen(A) - T = float(eltype(F.values)) - S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,))) - for i in eachindex(S.diag, C.diag, F.values) - S.diag[i], C.diag[i] = sincos(F.values[i]) - end - retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors' - for i in diagind(retmatS, IndexStyle(retmatS)) - retmatS[i] = real(retmatS[i]) - retmatC[i] = real(retmatC[i]) + if all(λ -> λ > 0, F.values) + return wrappertype(A)((F.vectors * Diagonal(log.(F.values))) * F.vectors') + else + return nonhermitianwrappertype(A)((F.vectors * Diagonal(log.(complex.(F.values)))) * F.vectors') end - return Hermitian(retmatS), Hermitian(retmatC) end - -for func in (:log, :sqrt) - # sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors - rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[] - rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0 - @eval begin - function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real} - F = eigen(A) - λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff - if all(λ -> λ ≥ λ₀, F.values) - return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors') - else - return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') - end - end - function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex} - n = checksquare(A) - F = eigen(A) - λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff - if all(λ -> λ ≥ λ₀, F.values) - retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors' - return Hermitian(retmat) - else - retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors' - return retmat - end - end +# sqrt has rtol kwarg to handle matrices that are semidefinite up to roundoff errors +function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1)) + F = eigen(A) + λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff + if all(λ -> λ ≥ λ₀, F.values) + return wrappertype(A)((F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors') + else + return nonhermitianwrappertype(A)((F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors') end end From 098966a81386e2915d007580045f6809f994f67b Mon Sep 17 00:00:00 2001 From: araujoms Date: Thu, 8 May 2025 13:09:55 +0200 Subject: [PATCH 2/5] formatting --- src/symmetric.jl | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index af2d8bb6..1d1ba5c7 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -850,9 +850,11 @@ function ^(A::SelfAdjoint, p::Real) isinteger(p) && return integerpow(A, p) F = eigen(A) if all(λ -> λ ≥ 0, F.values) - return wrappertype(A)((F.vectors * Diagonal((F.values).^p)) * F.vectors') + retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors' + return wrappertype(A)(retmat) else - return nonhermitianwrappertype(A)((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors') + retmat = (F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors' + return nonhermitianwrappertype(A)(retmat) end end function ^(A::SymSymTri{<:Complex}, p::Real) @@ -864,14 +866,16 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, @eval begin function ($func)(A::SelfAdjoint) F = eigen(A) - return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors') + retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' + return wrappertype(A)(retmat) end end end function cis(A::SelfAdjoint) F = eigen(A) - return nonhermitianwrappertype(A)(F.vectors .* cis.(F.values') * F.vectors') + retmat = F.vectors .* cis.(F.values') * F.vectors' + return nonhermitianwrappertype(A)(retmat) end for func in (:acos, :asin) @@ -879,9 +883,11 @@ for func in (:acos, :asin) function ($func)(A::SelfAdjoint) F = eigen(A) if all(λ -> -1 ≤ λ ≤ 1, F.values) - return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors') + retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' + return wrappertype(A)(retmat) else - return nonhermitianwrappertype(A)((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') + retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors' + return nonhermitianwrappertype(A)(retmat) end end end @@ -890,9 +896,11 @@ end function acosh(A::SelfAdjoint) F = eigen(A) if all(λ -> λ ≥ 1, F.values) - return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors') + retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors' + return wrappertype(A)(retmat) else - return nonhermitianwrappertype(A)((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors') + retmat = (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors' + return nonhermitianwrappertype(A)(retmat) end end @@ -910,9 +918,11 @@ end function log(A::SelfAdjoint) F = eigen(A) if all(λ -> λ > 0, F.values) - return wrappertype(A)((F.vectors * Diagonal(log.(F.values))) * F.vectors') + retmat = (F.vectors * Diagonal(log.(F.values))) * F.vectors' + return wrappertype(A)(retmat) else - return nonhermitianwrappertype(A)((F.vectors * Diagonal(log.(complex.(F.values)))) * F.vectors') + retmat = (F.vectors * Diagonal(log.(complex.(F.values)))) * F.vectors' + return nonhermitianwrappertype(A)(retmat) end end @@ -921,9 +931,11 @@ function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1)) F = eigen(A) λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff if all(λ -> λ ≥ λ₀, F.values) - return wrappertype(A)((F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors') + retmat = (F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors' + return wrappertype(A)(retmat) else - return nonhermitianwrappertype(A)((F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors') + retmat = (F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors' + return nonhermitianwrappertype(A)(retmat) end end From c4ef85eba0b79f51bbf2a7a44129a431fc96013f Mon Sep 17 00:00:00 2001 From: araujoms Date: Sun, 11 May 2025 12:35:00 +0200 Subject: [PATCH 3/5] formatting --- src/symmetric.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index 1d1ba5c7..d5278840 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -841,10 +841,11 @@ end ^(A::Hermitian, p::Integer) = sympow(A, p) function sympow(A, p::Integer) if p < 0 - return wrappertype(A)(Base.power_by_squaring(inv(A), -p)) + retmat = Base.power_by_squaring(inv(A), -p) else - return wrappertype(A)(Base.power_by_squaring(A, p)) + retmat = Base.power_by_squaring(A, p) end + return wrappertype(A)(retmat) end function ^(A::SelfAdjoint, p::Real) isinteger(p) && return integerpow(A, p) From d1678eb7047ec4362553c343c6520d9a6b9d492d Mon Sep 17 00:00:00 2001 From: araujoms Date: Sun, 11 May 2025 13:25:18 +0200 Subject: [PATCH 4/5] add test for complex asin/acos/acosh --- test/symmetric.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/symmetric.jl b/test/symmetric.jl index 9f727b8c..3f1397c0 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1199,4 +1199,14 @@ end end end +@testset "asin/acos/acosh for matrix outside the real domain" begin + M = [0 2;2 0] #eigenvalues are ±2 + for T ∈ (Float32, Float64, ComplexF32, ComplexF64) + M2 = Hermitian(T.(M)) + @test sin(asin(M2)) ≈ M2 + @test cos(acos(M2)) ≈ M2 + @test cosh(acosh(M2)) ≈ M2 + end +end + end # module TestSymmetric From 69ddee806ad1b315e159ff0288437e486884fdf0 Mon Sep 17 00:00:00 2001 From: araujoms Date: Mon, 12 May 2025 14:47:50 +0200 Subject: [PATCH 5/5] no point in avoiding 0 --- src/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/symmetric.jl b/src/symmetric.jl index d5278840..6d928dd1 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -918,7 +918,7 @@ end function log(A::SelfAdjoint) F = eigen(A) - if all(λ -> λ > 0, F.values) + if all(λ -> λ ≥ 0, F.values) retmat = (F.vectors * Diagonal(log.(F.values))) * F.vectors' return wrappertype(A)(retmat) else