Skip to content

Commit 3337456

Browse files
committed
Instantiate unitary
1 parent cb8515e commit 3337456

File tree

3 files changed

+20
-28
lines changed

3 files changed

+20
-28
lines changed

test/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63)
5656
end
5757
if m == n
5858
AT = Diagonal{T, Vector{T}}
59-
TestSuite.test_svd(AT, m; test_trunc = !(T GenericFloats))
60-
TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),); test_trunc = !(T GenericFloats))
59+
TestSuite.test_svd(AT, m)
60+
TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),))
6161
end
6262
end
6363
end

test/testsuite/TestSuite.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ isrightcomplete(Vᴴ, Nᴴ) = Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I
7676
isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), collect(N))
7777
isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N))
7878

79+
instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1]
80+
instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A))))
81+
7982
include("qr.jl")
8083
include("lq.jl")
8184
include("polar.jl")

test/testsuite/svd.jl

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,7 @@ function test_svd_compact(
3131
Ac = deepcopy(A)
3232
m, n = size(A)
3333
minmn = min(m, n)
34-
if VERSION < v"1.11"
35-
# This is type unstable on older versions of Julia.
36-
U, S, Vᴴ = svd_compact(A)
37-
else
38-
U, S, Vᴴ = @testinferred svd_compact(A)
39-
end
34+
U, S, Vᴴ = @testinferred svd_compact(A)
4035
@test size(U) == (m, minmn)
4136
@test S isa Diagonal{real(eltype(T))} && size(S) == (minmn, minmn)
4237
@test size(Vᴴ) == (minmn, n)
@@ -68,12 +63,7 @@ function test_svd_compact_algs(
6863
Ac = deepcopy(A)
6964
m, n = size(A)
7065
minmn = min(m, n)
71-
if VERSION < v"1.11"
72-
# This is type unstable on older versions of Julia.
73-
U, S, Vᴴ = svd_compact(A; alg)
74-
else
75-
U, S, Vᴴ = @testinferred svd_compact(A; alg)
76-
end
66+
U, S, Vᴴ = @testinferred svd_compact(A; alg)
7767
@test size(U) == (m, minmn)
7868
@test S isa Diagonal{real(eltype(T))} && size(S) == (minmn, minmn)
7969
@test size(Vᴴ) == (minmn, n)
@@ -202,12 +192,12 @@ function test_svd_trunc(
202192

203193
@testset "mix maxrank and tol" begin
204194
m4 = 4
205-
U = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
195+
U = instantiate_unitary(T, A, m4)
206196
Sdiag = similar(A, real(eltype(T)), m4)
207197
copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01])
208198
S = Diagonal(Sdiag)
209-
Vᴴ = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
210-
A = T <: Diagonal ? S : U * S * Vᴴ
199+
Vᴴ = instantiate_unitary(T, A, m4)
200+
A = U * S * Vᴴ
211201
for trunc_fun in (
212202
(rtol, maxrank) -> (; rtol, maxrank),
213203
(rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol),
@@ -224,12 +214,12 @@ function test_svd_trunc(
224214
@testset "specify truncation algorithm" begin
225215
atol = sqrt(eps(real(eltype(T))))
226216
m4 = 4
227-
U = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
217+
U = instantiate_unitary(T, A, m4)
228218
Sdiag = similar(A, real(eltype(T)), m4)
229-
copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
219+
copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01])
220+
Vᴴ = instantiate_unitary(T, A, m4)
230221
S = Diagonal(Sdiag)
231-
Vᴴ = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
232-
A = T <: Diagonal ? S : U * S * Vᴴ
222+
A = U * S * Vᴴ
233223
alg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(; atol = 0.2))
234224
U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; alg)
235225
@test diagview(S2) diagview(S)[1:2]
@@ -283,13 +273,12 @@ function test_svd_trunc_algs(
283273

284274
@testset "mix maxrank and tol" begin
285275
m4 = 4
286-
U = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
276+
U = instantiate_unitary(T, A, m4)
287277
Sdiag = similar(A, real(eltype(T)), m4)
288278
copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
289279
S = Diagonal(Sdiag)
290-
Vᴴ = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
291-
A = T <: Diagonal ? S : U * S * Vᴴ
292-
280+
Vᴴ = instantiate_unitary(T, A, m4)
281+
A = U * S * Vᴴ
293282
for trunc_fun in (
294283
(rtol, maxrank) -> (; rtol, maxrank),
295284
(rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol),
@@ -306,12 +295,12 @@ function test_svd_trunc_algs(
306295
@testset "specify truncation algorithm" begin
307296
atol = sqrt(eps(real(eltype(T))))
308297
m4 = 4
309-
U = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
298+
U = instantiate_unitary(T, A, m4)
310299
Sdiag = similar(A, real(eltype(T)), m4)
311300
copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
312301
S = Diagonal(Sdiag)
313-
Vᴴ = qr_compact(randn!(similar(A, eltype(T), m4, m4)))[1]
314-
A = T <: Diagonal ? S : U * S * Vᴴ
302+
Vᴴ = instantiate_unitary(T, A, m4)
303+
A = U * S * Vᴴ
315304
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = 0.2))
316305
U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; alg = truncalg)
317306
@test diagview(S2) diagview(S)[1:2]

0 commit comments

Comments
 (0)