Skip to content

Commit 6902e53

Browse files
committed
Try moving TestExtras around and adding projections
1 parent 3dd0ba9 commit 6902e53

File tree

9 files changed

+164
-94
lines changed

9 files changed

+164
-94
lines changed

test/lq.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using MatrixAlgebraKit
22
using Test
3-
using TestExtras
43
using StableRNGs
54
using LinearAlgebra: diag, I, Diagonal
65
using MatrixAlgebraKit: LQViaTransposedQR, LAPACK_HouseholderQR

test/polar.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using MatrixAlgebraKit
22
using Test
3-
using TestExtras
43
using StableRNGs
54
using LinearAlgebra: LinearAlgebra, I, isposdef
65

test/projections.jl

Lines changed: 20 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,96 +4,26 @@ using TestExtras
44
using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize!
66

7-
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8-
9-
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
10-
rng = StableRNG(123)
11-
m = 54
12-
noisefactor = eps(real(T))^(3 / 4)
13-
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
14-
for A in (randn(rng, T, m, m), Diagonal(randn(rng, T, m)))
15-
Ah = (A + A') / 2
16-
Aa = (A - A') / 2
17-
Ac = copy(A)
18-
19-
Bh = project_hermitian(A, alg)
20-
@test ishermitian(Bh)
21-
@test Bh Ah
22-
@test A == Ac
23-
Bh_approx = Bh + noisefactor * Aa
24-
# this is still hermitian for real Diagonal: |A - A'| == 0
25-
@test !ishermitian(Bh_approx) || norm(Aa) == 0
26-
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
27-
28-
Ba = project_antihermitian(A, alg)
29-
@test isantihermitian(Ba)
30-
@test Ba Aa
31-
@test A == Ac
32-
Ba_approx = Ba + noisefactor * Ah
33-
@test !isantihermitian(Ba_approx)
34-
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
35-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0
36-
37-
Bh = project_hermitian!(Ac, alg)
38-
@test Bh === Ac
39-
@test ishermitian(Bh)
40-
@test Bh Ah
41-
42-
copy!(Ac, A)
43-
Ba = project_antihermitian!(Ac, alg)
44-
@test Ba === Ac
45-
@test isantihermitian(Ba)
46-
@test Ba Aa
47-
end
7+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8+
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
9+
10+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
11+
using .TestSuite
12+
13+
m = 54
14+
for T in BLASFloats, n in (37, m, 63)
15+
TestSuite.seed_rng!(123)
16+
TestSuite.test_projections(T, (m, n))
17+
if CUDA.functional()
18+
TestSuite.test_projections(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
19+
TestSuite.test_projections(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false)
4820
end
49-
50-
# test approximate error calculation
51-
A = normalize!(randn(rng, T, m, m))
52-
Ah = project_hermitian(A)
53-
Aa = project_antihermitian(A)
54-
55-
Ah_approx = Ah + noisefactor * Aa
56-
ϵ = norm(project_antihermitian(Ah_approx))
57-
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
58-
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)
59-
60-
Aa_approx = Aa + noisefactor * Ah
61-
ϵ = norm(project_hermitian(Aa_approx))
62-
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
63-
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
64-
end
65-
66-
@testset "project_isometric! for T = $T" for T in BLASFloats
67-
rng = StableRNG(123)
68-
m = 54
69-
@testset "size ($m, $n)" for n in (37, m)
70-
k = min(m, n)
71-
if LinearAlgebra.LAPACK.version() < v"3.12.0"
72-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
73-
else
74-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi())
75-
end
76-
algs = (PolarViaSVD.(svdalgs)..., PolarNewton())
77-
@testset "algorithm $alg" for alg in algs
78-
A = randn(rng, T, m, n)
79-
W = project_isometric(A, alg)
80-
@test isisometric(W)
81-
W2 = project_isometric(W, alg)
82-
@test W2 W # stability of the projection
83-
@test W * (W' * A) ≈ A
84-
85-
Ac = similar(A)
86-
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
87-
@test W2 === W
88-
@test isisometric(W)
89-
90-
# test that W is closer to A then any other isometry
91-
for k in 1:10
92-
δA = randn(rng, T, m, n)
93-
W = project_isometric(A, alg)
94-
W2 = project_isometric(A + δA / 100, alg)
95-
@test norm(A - W2) > norm(A - W)
96-
end
97-
end
21+
if AMDGPU.functional()
22+
TestSuite.test_projections(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
23+
TestSuite.test_projections(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false)
9824
end
9925
end
26+
for T in (BLASFloats..., GenericFloats...)
27+
AT = Diagonal{T, Vector{T}}
28+
TestSuite.test_projections(AT, m; test_pivoted = false, test_blocksize = false)
29+
end

test/qr.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using MatrixAlgebraKit
22
using Test
3-
using TestExtras
43
using StableRNGs
54
using LinearAlgebra: diag, I, Diagonal
65
using CUDA, AMDGPU

test/testsuite/TestSuite.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Suite of tests that may be used for all packages inheriting from MatrixAlgebraKi
88
"""
99
module TestSuite
1010

11-
using Test, TestExtras
11+
using Test
1212
using MatrixAlgebraKit
1313
using MatrixAlgebraKit: diagview
1414
using LinearAlgebra: Diagonal, norm, istriu, istril
@@ -66,5 +66,6 @@ end
6666
include("qr.jl")
6767
include("lq.jl")
6868
include("polar.jl")
69+
include("projections.jl")
6970

7071
end

test/testsuite/lq.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using TestExtras
2+
13
function test_lq(T::Type, sz; kwargs...)
24
summary_str = testargs_summary(T, sz)
35
return @testset "lq $summary_str" begin

test/testsuite/polar.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using TestExtras
2+
13
function test_polar(T::Type, sz; kwargs...)
24
summary_str = testargs_summary(T, sz)
35
return @testset "polar $summary_str" begin

test/testsuite/projections.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
using TestExtras
2+
3+
function test_projections(T::Type, sz; kwargs...)
4+
summary_str = testargs_summary(T, sz)
5+
return @testset "projections $summary_str" begin
6+
test_project_antihermitian(T, sz; kwargs...)
7+
test_project_hermitian(T, sz; kwargs...)
8+
test_project_isometric(T, sz; kwargs...)
9+
end
10+
end
11+
12+
function test_project_antihermitian(
13+
T::Type, sz;
14+
atol::Real = 0, rtol::Real = precision(T),
15+
kwargs...
16+
)
17+
summary_str = testargs_summary(T, sz)
18+
return @testset "project_antihermitian! $summary_str" begin
19+
noisefactor = eps(real(T))^(3 / 4)
20+
algs = (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
21+
@testset "algorithm $alg" for alg in algs
22+
A = instantiate_matrix(T, sz)
23+
Ac = deepcopy(A)
24+
Ah = (A + A') / 2
25+
Aa = (A - A') / 2
26+
27+
Ba = project_antihermitian(A, alg)
28+
@test isantihermitian(Ba)
29+
@test Ba Aa
30+
@test A == Ac
31+
Ba_approx = Ba + noisefactor * Ah
32+
@test !isantihermitian(Ba_approx)
33+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
34+
35+
copy!(Ac, A)
36+
Ba = project_antihermitian!(Ac, alg)
37+
@test Ba === Ac
38+
@test isantihermitian(Ba)
39+
@test Ba Aa
40+
end
41+
42+
# test approximate error calculation
43+
A = normalize!(randn(rng, T, m, m))
44+
Ah = project_hermitian(A)
45+
Aa = project_antihermitian(A)
46+
47+
Ah_approx = Ah + noisefactor * Aa
48+
ϵ = norm(project_antihermitian(Ah_approx))
49+
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
50+
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)
51+
52+
Aa_approx = Aa + noisefactor * Ah
53+
ϵ = norm(project_hermitian(Aa_approx))
54+
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
55+
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
56+
end
57+
end
58+
59+
function test_project_hermitian(
60+
T::Type, sz;
61+
atol::Real = 0, rtol::Real = precision(T),
62+
kwargs...
63+
)
64+
summary_str = testargs_summary(T, sz)
65+
return @testset "project_hermitian! $summary_str" begin
66+
noisefactor = eps(real(T))^(3 / 4)
67+
algs = (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
68+
@testset "algorithm $alg" for alg in algs
69+
A = instantiate_matrix(T, sz)
70+
Ac = deepcopy(A)
71+
Ah = (A + A') / 2
72+
Aa = (A - A') / 2
73+
74+
Bh = project_hermitian(A, alg)
75+
@test ishermitian(Bh)
76+
@test Bh Ah
77+
@test A == Ac
78+
Bh_approx = Bh + noisefactor * Aa
79+
@test !ishermitian(Bh_approx)
80+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
81+
82+
Bh = project_hermitian!(Ac, alg)
83+
@test Bh === Ac
84+
@test ishermitian(Bh)
85+
@test Bh Ah
86+
end
87+
88+
# test approximate error calculation
89+
A = normalize!(randn(rng, T, m, m))
90+
Ah = project_hermitian(A)
91+
Aa = project_antihermitian(A)
92+
93+
Ah_approx = Ah + noisefactor * Aa
94+
ϵ = norm(project_antihermitian(Ah_approx))
95+
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
96+
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)
97+
98+
Aa_approx = Aa + noisefactor * Ah
99+
ϵ = norm(project_hermitian(Aa_approx))
100+
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
101+
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
102+
end
103+
end
104+
105+
function test_project_isometric(
106+
T::Type, sz;
107+
atol::Real = 0, rtol::Real = precision(T),
108+
kwargs...
109+
)
110+
summary_str = testargs_summary(T, sz)
111+
return @testset "project_isometric! $summary_str" begin
112+
algs = (PolarViaSVD(), PolarNewton())
113+
@testset "algorithm $alg" for alg in algs
114+
A = instantiate_matrix(T, sz)
115+
Ac = deepcopy(A)
116+
k = min(size(A)...)
117+
W = project_isometric(A, alg)
118+
@test isisometric(W)
119+
W2 = project_isometric(W, alg)
120+
@test W2 W # stability of the projection
121+
@test W * (W' * A) ≈ A
122+
123+
W2 = @constinferred project_isometric!(Ac, W, alg)
124+
@test W2 === W
125+
@test isisometric(W)
126+
127+
# test that W is closer to A then any other isometry
128+
for k in 1:10
129+
δA = randn(rng, T, m, n)
130+
W = project_isometric(A, alg)
131+
W2 = project_isometric(A + δA / 100, alg)
132+
@test norm(A - W2) > norm(A - W)
133+
end
134+
end
135+
end
136+
end

test/testsuite/qr.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using TestExtras
2+
13
function test_qr(T::Type, sz; kwargs...)
24
summary_str = testargs_summary(T, sz)
35
return @testset "qr $summary_str" begin

0 commit comments

Comments
 (0)