Skip to content

Commit 4d21841

Browse files
authored
cleanup functions of Hermitian matrices (#55951)
The functions of Hermitian matrices are a bit of a mess. For example, if we have a Hermitian matrix `a` with negative eigenvalues, `a^0.5` doesn't produce the `Symmetric` wrapper, but `sqrt(a)` does. On the other hand, if we have a positive definite `b`, `b^0.5` will be `Hermitian`, but `sqrt(b)` will be `Symmetric`: ```julia using LinearAlgebra a = Hermitian([1.0 2.0;2.0 1.0]) a^0.5 sqrt(a) b = Hermitian([2.0 1.0; 1.0 2.0]) b^0.5 sqrt(b) ``` This sort of arbitrary assignment of wrappers happens with pretty much all functions defined there. There's also some oddities, such as `cis` being the only function defined for `SymTridiagonal`, even though all `eigen`-based functions work, and `cbrt` being the only function not defined for complex Hermitian matrices. I did a cleanup: I defined all functions for `SymTridiagonal` and `Hermitian{<:Complex}`, and always assigned the appropriate wrapper, preserving the input one when possible. There's an inconsistency remaining that I didn't fix, that only `sqrt` and `log` accept a tolerance argument, as changing that is probably breaking. There were also hardly any tests that I could find (only `exp`, `log`, `cis`, and `sqrt`). I'm happy to add them if it's desired.
1 parent c2a2e38 commit 4d21841

File tree

1 file changed

+93
-65
lines changed

1 file changed

+93
-65
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 93 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -810,26 +810,32 @@ end
810810
# Matrix functions
811811
^(A::Symmetric{<:Real}, p::Integer) = sympow(A, p)
812812
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
813-
function sympow(A::Symmetric, p::Integer)
814-
if p < 0
815-
return Symmetric(Base.power_by_squaring(inv(A), -p))
816-
else
817-
return Symmetric(Base.power_by_squaring(A, p))
818-
end
819-
end
820-
function ^(A::Symmetric{<:Real}, p::Real)
821-
isinteger(p) && return integerpow(A, p)
822-
F = eigen(A)
823-
if all-> λ 0, F.values)
824-
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
825-
else
826-
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
813+
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
814+
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
815+
for hermtype in (:Symmetric, :SymTridiagonal)
816+
@eval begin
817+
function sympow(A::$hermtype, p::Integer)
818+
if p < 0
819+
return Symmetric(Base.power_by_squaring(inv(A), -p))
820+
else
821+
return Symmetric(Base.power_by_squaring(A, p))
822+
end
823+
end
824+
function ^(A::$hermtype{<:Real}, p::Real)
825+
isinteger(p) && return integerpow(A, p)
826+
F = eigen(A)
827+
if all-> λ 0, F.values)
828+
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
829+
else
830+
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
831+
end
832+
end
833+
function ^(A::$hermtype{<:Complex}, p::Real)
834+
isinteger(p) && return integerpow(A, p)
835+
return Symmetric(schurpow(A, p))
836+
end
827837
end
828838
end
829-
function ^(A::Symmetric{<:Complex}, p::Real)
830-
isinteger(p) && return integerpow(A, p)
831-
return Symmetric(schurpow(A, p))
832-
end
833839
function ^(A::Hermitian, p::Integer)
834840
if p < 0
835841
retmat = Base.power_by_squaring(inv(A), -p)
@@ -855,16 +861,25 @@ function ^(A::Hermitian{T}, p::Real) where T
855861
return Hermitian(retmat)
856862
end
857863
else
858-
return (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
864+
retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
865+
if T <: Real
866+
return Symmetric(retmat)
867+
else
868+
return retmat
869+
end
859870
end
860871
end
861872

862-
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh)
863-
@eval begin
864-
function ($func)(A::HermOrSym{<:Real})
865-
F = eigen(A)
866-
return Symmetric((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
873+
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
874+
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
875+
@eval begin
876+
function ($func)(A::$hermtype{<:Real})
877+
F = eigen(A)
878+
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
879+
end
867880
end
881+
end
882+
@eval begin
868883
function ($func)(A::Hermitian{<:Complex})
869884
n = checksquare(A)
870885
F = eigen(A)
@@ -877,23 +892,34 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh)
877892
end
878893
end
879894

880-
function cis(A::Union{RealHermSymComplexHerm,SymTridiagonal{<:Real}})
895+
for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal)
896+
@eval begin
897+
function cis(A::$wrapper{<:Real})
898+
F = eigen(A)
899+
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
900+
end
901+
end
902+
end
903+
function cis(A::Hermitian{<:Complex})
881904
F = eigen(A)
882-
# The returned matrix is unitary, and is complex-symmetric for real A
883905
return F.vectors .* cis.(F.values') * F.vectors'
884906
end
885907

908+
886909
for func in (:acos, :asin)
887-
@eval begin
888-
function ($func)(A::HermOrSym{<:Real})
889-
F = eigen(A)
890-
if all-> -1 λ 1, F.values)
891-
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
892-
else
893-
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
910+
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
911+
@eval begin
912+
function ($func)(A::$hermtype{<:Real})
913+
F = eigen(A)
914+
if all-> -1 λ 1, F.values)
915+
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
916+
else
917+
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
918+
end
894919
end
895-
return Symmetric(retmat)
896920
end
921+
end
922+
@eval begin
897923
function ($func)(A::Hermitian{<:Complex})
898924
n = checksquare(A)
899925
F = eigen(A)
@@ -910,14 +936,17 @@ for func in (:acos, :asin)
910936
end
911937
end
912938

913-
function acosh(A::HermOrSym{<:Real})
914-
F = eigen(A)
915-
if all-> λ 1, F.values)
916-
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
917-
else
918-
retmat = (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
939+
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
940+
@eval begin
941+
function acosh(A::$hermtype{<:Real})
942+
F = eigen(A)
943+
if all-> λ 1, F.values)
944+
return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
945+
else
946+
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
947+
end
948+
end
919949
end
920-
return Symmetric(retmat)
921950
end
922951
function acosh(A::Hermitian{<:Complex})
923952
n = checksquare(A)
@@ -933,14 +962,18 @@ function acosh(A::Hermitian{<:Complex})
933962
end
934963
end
935964

936-
function sincos(A::HermOrSym{<:Real})
937-
n = checksquare(A)
938-
F = eigen(A)
939-
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
940-
for i in 1:n
941-
S.diag[i], C.diag[i] = sincos(F.values[i])
965+
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
966+
@eval begin
967+
function sincos(A::$hermtype{<:Real})
968+
n = checksquare(A)
969+
F = eigen(A)
970+
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
971+
for i in 1:n
972+
S.diag[i], C.diag[i] = sincos(F.values[i])
973+
end
974+
return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors')
975+
end
942976
end
943-
return Symmetric((F.vectors * S) * F.vectors'), Symmetric((F.vectors * C) * F.vectors')
944977
end
945978
function sincos(A::Hermitian{<:Complex})
946979
n = checksquare(A)
@@ -962,18 +995,20 @@ for func in (:log, :sqrt)
962995
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
963996
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
964997
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
965-
@eval begin
966-
function ($func)(A::HermOrSym{T}; $(rtolarg...)) where {T<:Real}
967-
F = eigen(A)
968-
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
969-
if all-> λ λ₀, F.values)
970-
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
971-
else
972-
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
998+
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
999+
@eval begin
1000+
function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real}
1001+
F = eigen(A)
1002+
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
1003+
if all-> λ λ₀, F.values)
1004+
return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
1005+
else
1006+
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
1007+
end
9731008
end
974-
return Symmetric(retmat)
9751009
end
976-
1010+
end
1011+
@eval begin
9771012
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
9781013
n = checksquare(A)
9791014
F = eigen(A)
@@ -992,13 +1027,6 @@ for func in (:log, :sqrt)
9921027
end
9931028
end
9941029

995-
# Cube root of a real-valued symmetric matrix
996-
function cbrt(A::HermOrSym{<:Real})
997-
F = eigen(A)
998-
A = F.vectors * Diagonal(cbrt.(F.values)) * F.vectors'
999-
return A
1000-
end
1001-
10021030
"""
10031031
hermitianpart(A::AbstractMatrix, uplo::Symbol=:U) -> Hermitian
10041032

0 commit comments

Comments
 (0)