Skip to content

Commit 0fa6a16

Browse files
committed
cleanup hermitian matrix functions
1 parent 7f354f4 commit 0fa6a16

File tree

1 file changed

+38
-113
lines changed

1 file changed

+38
-113
lines changed

src/symmetric.jl

Lines changed: 38 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,15 @@ const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
224224
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
225225
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
226226
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
227-
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}
227+
const SelfAdjoint = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian}
228228

229229
wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
230230
wrappertype(::Hermitian) = Hermitian
231231

232+
nonhermitianwrappertype(::SymSymTri{<:Real}) = Symmetric
233+
nonhermitianwrappertype(::Hermitian{<:Real}) = Symmetric
234+
nonhermitianwrappertype(::Hermitian) = identity
235+
232236
size(A::HermOrSym) = size(A.data)
233237
axes(A::HermOrSym) = axes(A.data)
234238
@inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
@@ -834,119 +838,65 @@ end
834838
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
835839
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
836840
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
837-
function sympow(A::SymSymTri, p::Integer)
838-
if p < 0
839-
return Symmetric(Base.power_by_squaring(inv(A), -p))
840-
else
841-
return Symmetric(Base.power_by_squaring(A, p))
842-
end
843-
end
844-
for hermtype in (:Symmetric, :SymTridiagonal)
845-
@eval begin
846-
function ^(A::$hermtype{<:Real}, p::Real)
847-
isinteger(p) && return integerpow(A, p)
848-
F = eigen(A)
849-
if all-> λ 0, F.values)
850-
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
851-
else
852-
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
853-
end
854-
end
855-
function ^(A::$hermtype{<:Complex}, p::Real)
856-
isinteger(p) && return integerpow(A, p)
857-
return Symmetric(schurpow(A, p))
858-
end
859-
end
860-
end
861-
function ^(A::Hermitian, p::Integer)
841+
^(A::Hermitian, p::Integer) = sympow(A, p)
842+
function sympow(A, p::Integer)
862843
if p < 0
863-
retmat = Base.power_by_squaring(inv(A), -p)
844+
return wrappertype(A)(Base.power_by_squaring(inv(A), -p))
864845
else
865-
retmat = Base.power_by_squaring(A, p)
846+
return wrappertype(A)(Base.power_by_squaring(A, p))
866847
end
867-
return Hermitian(retmat)
868848
end
869-
function ^(A::Hermitian{T}, p::Real) where T
849+
function ^(A::SelfAdjoint, p::Real)
870850
isinteger(p) && return integerpow(A, p)
871851
F = eigen(A)
872852
if all-> λ 0, F.values)
873-
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
874-
return Hermitian(retmat)
853+
return wrappertype(A)((F.vectors * Diagonal((F.values).^p)) * F.vectors')
875854
else
876-
retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
877-
if T <: Real
878-
return Symmetric(retmat)
879-
else
880-
return retmat
881-
end
855+
return nonhermitianwrappertype(A)((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
882856
end
883857
end
858+
function ^(A::SymSymTri{<:Complex}, p::Real)
859+
isinteger(p) && return integerpow(A, p)
860+
return Symmetric(schurpow(A, p))
861+
end
884862

885863
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
886864
@eval begin
887-
function ($func)(A::RealHermSymSymTri)
865+
function ($func)(A::SelfAdjoint)
888866
F = eigen(A)
889867
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
890868
end
891-
function ($func)(A::Hermitian{<:Complex})
892-
F = eigen(A)
893-
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
894-
return Hermitian(retmat)
895-
end
896869
end
897870
end
898871

899-
function cis(A::RealHermSymSymTri)
900-
F = eigen(A)
901-
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
902-
end
903-
function cis(A::Hermitian{<:Complex})
872+
function cis(A::SelfAdjoint)
904873
F = eigen(A)
905-
return F.vectors .* cis.(F.values') * F.vectors'
874+
return nonhermitianwrappertype(A)(F.vectors .* cis.(F.values') * F.vectors')
906875
end
907876

908-
909877
for func in (:acos, :asin)
910878
@eval begin
911-
function ($func)(A::RealHermSymSymTri)
879+
function ($func)(A::SelfAdjoint)
912880
F = eigen(A)
913881
if all-> -1 λ 1, F.values)
914882
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
915883
else
916-
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
917-
end
918-
end
919-
function ($func)(A::Hermitian{<:Complex})
920-
F = eigen(A)
921-
if all-> -1 λ 1, F.values)
922-
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
923-
return Hermitian(retmat)
924-
else
925-
return (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
884+
return nonhermitianwrappertype(A)((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
926885
end
927886
end
928887
end
929888
end
930889

931-
function acosh(A::RealHermSymSymTri)
890+
function acosh(A::SelfAdjoint)
932891
F = eigen(A)
933892
if all-> λ 1, F.values)
934893
return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
935894
else
936-
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
937-
end
938-
end
939-
function acosh(A::Hermitian{<:Complex})
940-
F = eigen(A)
941-
if all-> λ 1, F.values)
942-
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
943-
return Hermitian(retmat)
944-
else
945-
return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
895+
return nonhermitianwrappertype(A)((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
946896
end
947897
end
948898

949-
function sincos(A::RealHermSymSymTri)
899+
function sincos(A::SelfAdjoint)
950900
n = checksquare(A)
951901
F = eigen(A)
952902
T = float(eltype(F.values))
@@ -956,49 +906,24 @@ function sincos(A::RealHermSymSymTri)
956906
end
957907
return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors')
958908
end
959-
function sincos(A::Hermitian{<:Complex})
960-
n = checksquare(A)
909+
910+
function log(A::SelfAdjoint)
961911
F = eigen(A)
962-
T = float(eltype(F.values))
963-
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
964-
for i in eachindex(S.diag, C.diag, F.values)
965-
S.diag[i], C.diag[i] = sincos(F.values[i])
966-
end
967-
retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
968-
for i in diagind(retmatS, IndexStyle(retmatS))
969-
retmatS[i] = real(retmatS[i])
970-
retmatC[i] = real(retmatC[i])
912+
if all-> λ > 0, F.values)
913+
return wrappertype(A)((F.vectors * Diagonal(log.(F.values))) * F.vectors')
914+
else
915+
return nonhermitianwrappertype(A)((F.vectors * Diagonal(log.(complex.(F.values)))) * F.vectors')
971916
end
972-
return Hermitian(retmatS), Hermitian(retmatC)
973917
end
974918

975-
976-
for func in (:log, :sqrt)
977-
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
978-
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
979-
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
980-
@eval begin
981-
function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
982-
F = eigen(A)
983-
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
984-
if all-> λ λ₀, F.values)
985-
return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
986-
else
987-
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
988-
end
989-
end
990-
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
991-
n = checksquare(A)
992-
F = eigen(A)
993-
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
994-
if all-> λ λ₀, F.values)
995-
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
996-
return Hermitian(retmat)
997-
else
998-
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
999-
return retmat
1000-
end
1001-
end
919+
# sqrt has rtol kwarg to handle matrices that are semidefinite up to roundoff errors
920+
function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1))
921+
F = eigen(A)
922+
λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
923+
if all-> λ λ₀, F.values)
924+
return wrappertype(A)((F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors')
925+
else
926+
return nonhermitianwrappertype(A)((F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors')
1002927
end
1003928
end
1004929

0 commit comments

Comments
 (0)