Skip to content

Commit 86bf1b2

Browse files
dkarraschandreasnoack
authored andcommitted
Add svd for real Symmetric and Hermitian matrices (#32017)
* add svd for realsym/hermitian, and tests * shift improved code from svd.jl to symmetric.jl * minimize allocations * combine tests * move test out of loop * clean up svd tests
1 parent dbd04d4 commit 86bf1b2

File tree

2 files changed

+41
-31
lines changed

2 files changed

+41
-31
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,22 @@ eigvals!(A::Hermitian{T,S}, B::Hermitian{T,S}) where {T<:BlasComplex,S<:StridedM
689689

690690
eigvecs(A::HermOrSym) = eigvecs(eigen(A))
691691

692+
function svd(A::RealHermSymComplexHerm, full::Bool=false)
693+
vals, vecs = eigen(A)
694+
I = sortperm(vals; by=abs, rev=true)
695+
permute!(vals, I)
696+
Base.permutecols!!(vecs, I) # left-singular vectors
697+
V = copy(vecs) # right-singular vectors
698+
# shifting -1 from singular values to right-singular vectors
699+
@inbounds for i = 1:length(vals)
700+
if vals[i] < 0
701+
vals[i] = -vals[i]
702+
for j = 1:size(V,1); V[j,i] = -V[j,i]; end
703+
end
704+
end
705+
return SVD(vecs, vals, V')
706+
end
707+
692708
function svdvals!(A::RealHermSymComplexHerm)
693709
vals = eigvals!(A)
694710
for i = 1:length(vals)

stdlib/LinearAlgebra/test/svd.jl

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,15 @@ end
3535

3636
n = 10
3737

38-
# Split n into 2 parts for tests needing two matrices
39-
n1 = div(n, 2)
40-
n2 = 2*n1
41-
4238
Random.seed!(1234321)
4339

4440
areal = randn(n,n)/2
4541
aimg = randn(n,n)/2
46-
a2real = randn(n,n)/2
47-
a2img = randn(n,n)/2
4842

4943
@testset for eltya in (Float32, Float64, ComplexF32, ComplexF64, Int)
5044
aa = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal)
51-
aa2 = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(a2real, a2img) : a2real)
5245
asym = aa' + aa # symmetric indefinite
53-
apd = aa' * aa # symmetric positive-definite
54-
for (a, a2) in ((aa, aa2), (view(aa, 1:n, 1:n), view(aa2, 1:n, 1:n)))
55-
ε = εa = eps(abs(float(one(eltya))))
56-
46+
for a in (aa, view(aa, 1:n, 1:n))
5747
usv = svd(a)
5848
@testset "singular value decomposition" begin
5949
@test usv.S === svdvals(usv)
@@ -72,28 +62,20 @@ a2img = randn(n,n)/2
7262
@test svdz.Vt Matrix{eltya}(I, 0, 0)
7363
end
7464
end
75-
usv = svd(a')
76-
@testset "singular value decomposition of adjoint" begin
77-
@test usv.S === svdvals(usv)
78-
@test usv.U * (Diagonal(usv.S) * usv.Vt) a'
79-
@test convert(Array, usv) a'
80-
@test usv.Vt' usv.V
81-
@test_throws ErrorException usv.Z
82-
b = rand(eltya,n)
83-
@test usv\b a'\b
84-
end
85-
usv = svd(transpose(a))
86-
@testset "singular value decomposition of transpose" begin
87-
@test usv.S === svdvals(usv)
88-
@test usv.U * (Diagonal(usv.S) * usv.Vt) transpose(a)
89-
@test convert(Array, usv) transpose(a)
90-
@test usv.Vt' usv.V
91-
@test_throws ErrorException usv.Z
92-
b = rand(eltya,n)
93-
@test usv\b transpose(a)\b
65+
@testset "singular value decomposition of adjoint/transpose" begin
66+
for transform in (adjoint, transpose)
67+
usv = svd(transform(a))
68+
@test usv.S === svdvals(usv)
69+
@test usv.U * (Diagonal(usv.S) * usv.Vt) transform(a)
70+
@test convert(Array, usv) transform(a)
71+
@test usv.Vt' usv.V
72+
@test_throws ErrorException usv.Z
73+
b = rand(eltya,n)
74+
@test usv\b transform(a)\b
75+
end
9476
end
9577
@testset "Generalized svd" begin
96-
a_svd = a[1:n1, :]
78+
a_svd = a[1:div(n, 2), :]
9779
gsvd = svd(a,a_svd)
9880
@test gsvd.U*gsvd.D1*gsvd.R*gsvd.Q' a
9981
@test gsvd.V*gsvd.D2*gsvd.R*gsvd.Q' a_svd
@@ -121,6 +103,18 @@ a2img = randn(n,n)/2
121103
@test gsvd.V*gsvd.D2*gsvd.R*gsvd.Q' c
122104
end
123105
end
106+
@testset "singular value decomposition of Hermitian/real-Symmetric" begin
107+
for T in (eltya <: Real ? (Symmetric, Hermitian) : (Hermitian,))
108+
usv = svd(T(asym))
109+
@test usv.S === svdvals(usv)
110+
@test usv.U * (Diagonal(usv.S) * usv.Vt) T(asym)
111+
@test convert(Array, usv) T(asym)
112+
@test usv.Vt' usv.V
113+
@test_throws ErrorException usv.Z
114+
b = rand(eltya,n)
115+
@test usv\b T(asym)\b
116+
end
117+
end
124118
if eltya <: LinearAlgebra.BlasReal
125119
@testset "Number input" begin
126120
x, y = randn(eltya, 2)

0 commit comments

Comments
 (0)