Skip to content

Commit fb5e96a

Browse files
authored
Merge identical methods for Symmetric/Hermitian and SymTridiagonal (JuliaLang#56434)
Since the methods do identical things, we may define each method once for a union of types instead of defining methods for each type.
1 parent 435516d commit fb5e96a

File tree

1 file changed

+65
-86
lines changed

1 file changed

+65
-86
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 65 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,16 @@ convert(::Type{T}, m::Union{Symmetric,Hermitian}) where {T<:Hermitian} = m isa T
219219

220220
const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}}
221221
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
222+
const SymSymTri{T} = Union{Symmetric{T}, SymTridiagonal{T}}
223+
const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
222224
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
223225
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
226+
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
224227
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}
225228

229+
wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
230+
wrappertype(::Hermitian) = Hermitian
231+
226232
size(A::HermOrSym) = size(A.data)
227233
axes(A::HermOrSym) = axes(A.data)
228234
@inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
@@ -814,15 +820,15 @@ end
814820
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
815821
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
816822
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
823+
function sympow(A::SymSymTri, p::Integer)
824+
if p < 0
825+
return Symmetric(Base.power_by_squaring(inv(A), -p))
826+
else
827+
return Symmetric(Base.power_by_squaring(A, p))
828+
end
829+
end
817830
for hermtype in (:Symmetric, :SymTridiagonal)
818831
@eval begin
819-
function sympow(A::$hermtype, p::Integer)
820-
if p < 0
821-
return Symmetric(Base.power_by_squaring(inv(A), -p))
822-
else
823-
return Symmetric(Base.power_by_squaring(A, p))
824-
end
825-
end
826832
function ^(A::$hermtype{<:Real}, p::Real)
827833
isinteger(p) && return integerpow(A, p)
828834
F = eigen(A)
@@ -844,8 +850,8 @@ function ^(A::Hermitian, p::Integer)
844850
else
845851
retmat = Base.power_by_squaring(A, p)
846852
end
847-
for i = 1:size(A,1)
848-
retmat[i,i] = real(retmat[i,i])
853+
for i in diagind(retmat, IndexStyle(retmat))
854+
retmat[i] = real(retmat[i])
849855
end
850856
return Hermitian(retmat)
851857
end
@@ -857,8 +863,8 @@ function ^(A::Hermitian{T}, p::Real) where T
857863
if T <: Real
858864
return Hermitian(retmat)
859865
else
860-
for i = 1:size(A,1)
861-
retmat[i,i] = real(retmat[i,i])
866+
for i in diagind(retmat, IndexStyle(retmat))
867+
retmat[i] = real(retmat[i])
862868
end
863869
return Hermitian(retmat)
864870
end
@@ -873,34 +879,25 @@ function ^(A::Hermitian{T}, p::Real) where T
873879
end
874880

875881
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
876-
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
877-
@eval begin
878-
function ($func)(A::$hermtype{<:Real})
879-
F = eigen(A)
880-
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
881-
end
882-
end
883-
end
884882
@eval begin
883+
function ($func)(A::RealHermSymSymTri)
884+
F = eigen(A)
885+
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
886+
end
885887
function ($func)(A::Hermitian{<:Complex})
886-
n = checksquare(A)
887888
F = eigen(A)
888889
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
889-
for i = 1:n
890-
retmat[i,i] = real(retmat[i,i])
890+
for i in diagind(retmat, IndexStyle(retmat))
891+
retmat[i] = real(retmat[i])
891892
end
892893
return Hermitian(retmat)
893894
end
894895
end
895896
end
896897

897-
for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal)
898-
@eval begin
899-
function cis(A::$wrapper{<:Real})
900-
F = eigen(A)
901-
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
902-
end
903-
end
898+
function cis(A::RealHermSymSymTri)
899+
F = eigen(A)
900+
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
904901
end
905902
function cis(A::Hermitian{<:Complex})
906903
F = eigen(A)
@@ -909,26 +906,21 @@ end
909906

910907

911908
for func in (:acos, :asin)
912-
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
913-
@eval begin
914-
function ($func)(A::$hermtype{<:Real})
915-
F = eigen(A)
916-
if all-> -1 λ 1, F.values)
917-
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
918-
else
919-
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
920-
end
909+
@eval begin
910+
function ($func)(A::RealHermSymSymTri)
911+
F = eigen(A)
912+
if all-> -1 λ 1, F.values)
913+
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
914+
else
915+
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
921916
end
922917
end
923-
end
924-
@eval begin
925918
function ($func)(A::Hermitian{<:Complex})
926-
n = checksquare(A)
927919
F = eigen(A)
928920
if all-> -1 λ 1, F.values)
929921
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
930-
for i = 1:n
931-
retmat[i,i] = real(retmat[i,i])
922+
for i in diagind(retmat, IndexStyle(retmat))
923+
retmat[i] = real(retmat[i])
932924
end
933925
return Hermitian(retmat)
934926
else
@@ -938,58 +930,49 @@ for func in (:acos, :asin)
938930
end
939931
end
940932

941-
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
942-
@eval begin
943-
function acosh(A::$hermtype{<:Real})
944-
F = eigen(A)
945-
if all-> λ 1, F.values)
946-
return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
947-
else
948-
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
949-
end
950-
end
933+
function acosh(A::RealHermSymSymTri)
934+
F = eigen(A)
935+
if all-> λ 1, F.values)
936+
return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
937+
else
938+
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
951939
end
952940
end
953941
function acosh(A::Hermitian{<:Complex})
954-
n = checksquare(A)
955942
F = eigen(A)
956943
if all-> λ 1, F.values)
957944
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
958-
for i = 1:n
959-
retmat[i,i] = real(retmat[i,i])
945+
for i in diagind(retmat, IndexStyle(retmat))
946+
retmat[i] = real(retmat[i])
960947
end
961948
return Hermitian(retmat)
962949
else
963950
return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
964951
end
965952
end
966953

967-
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
968-
@eval begin
969-
function sincos(A::$hermtype{<:Real})
970-
n = checksquare(A)
971-
F = eigen(A)
972-
T = float(eltype(F.values))
973-
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
974-
for i in 1:n
975-
S.diag[i], C.diag[i] = sincos(F.values[i])
976-
end
977-
return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors')
978-
end
954+
function sincos(A::RealHermSymSymTri)
955+
n = checksquare(A)
956+
F = eigen(A)
957+
T = float(eltype(F.values))
958+
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
959+
for i in eachindex(S.diag, C.diag, F.values)
960+
S.diag[i], C.diag[i] = sincos(F.values[i])
979961
end
962+
return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors')
980963
end
981964
function sincos(A::Hermitian{<:Complex})
982965
n = checksquare(A)
983966
F = eigen(A)
984967
T = float(eltype(F.values))
985968
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
986-
for i in 1:n
969+
for i in eachindex(S.diag, C.diag, F.values)
987970
S.diag[i], C.diag[i] = sincos(F.values[i])
988971
end
989972
retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
990-
for i = 1:n
991-
retmatS[i,i] = real(retmatS[i,i])
992-
retmatC[i,i] = real(retmatC[i,i])
973+
for i in diagind(retmatS, IndexStyle(retmatS))
974+
retmatS[i] = real(retmatS[i])
975+
retmatC[i] = real(retmatC[i])
993976
end
994977
return Hermitian(retmatS), Hermitian(retmatC)
995978
end
@@ -999,28 +982,24 @@ for func in (:log, :sqrt)
999982
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
1000983
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
1001984
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
1002-
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
1003-
@eval begin
1004-
function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real}
1005-
F = eigen(A)
1006-
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
1007-
if all-> λ λ₀, F.values)
1008-
return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
1009-
else
1010-
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
1011-
end
985+
@eval begin
986+
function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
987+
F = eigen(A)
988+
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
989+
if all-> λ λ₀, F.values)
990+
return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
991+
else
992+
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
1012993
end
1013994
end
1014-
end
1015-
@eval begin
1016995
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
1017996
n = checksquare(A)
1018997
F = eigen(A)
1019998
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
1020999
if all-> λ λ₀, F.values)
10211000
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
1022-
for i = 1:n
1023-
retmat[i,i] = real(retmat[i,i])
1001+
for i in diagind(retmat, IndexStyle(retmat))
1002+
retmat[i] = real(retmat[i])
10241003
end
10251004
return Hermitian(retmat)
10261005
else

0 commit comments

Comments
 (0)