Skip to content

Commit d6825c7

Browse files
committed
Use TestSuite for projections tests
1 parent 9d1ffb8 commit d6825c7

File tree

6 files changed

+166
-309
lines changed

6 files changed

+166
-309
lines changed

test/amd/projections.jl

Lines changed: 0 additions & 104 deletions
This file was deleted.

test/cuda/projections.jl

Lines changed: 0 additions & 104 deletions
This file was deleted.

test/projections.jl

Lines changed: 21 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,35 @@
11
using MatrixAlgebraKit
2-
using MatrixAlgebraKit: check_hermitian, default_hermitian_tol
32
using Test
43
using TestExtras
54
using StableRNGs
65
using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize!
6+
using CUDA, AMDGPU
77

8-
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
910

10-
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
11-
rng = StableRNG(123)
12-
m = 54
13-
noisefactor = eps(real(T))^(3 / 4)
11+
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
12+
using .TestSuite
1413

15-
mat0 = zeros(T, (1, 1))
16-
@test ishermitian(mat0)
17-
@test ishermitian(mat0; atol = default_hermitian_tol(mat0))
18-
@test isnothing(check_hermitian(mat0))
14+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1915

20-
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
21-
for A in (randn(rng, T, m, m), Diagonal(randn(rng, T, m)))
22-
Ah = (A + A') / 2
23-
Aa = (A - A') / 2
24-
Ac = copy(A)
2516

26-
Bh = project_hermitian(A, alg)
27-
@test ishermitian(Bh)
28-
@test Bh Ah
29-
@test A == Ac
30-
Bh_approx = Bh + noisefactor * Aa
31-
# this is still hermitian for real Diagonal: |A - A'| == 0
32-
@test !ishermitian(Bh_approx) || norm(Aa) == 0
33-
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
34-
35-
Ba = project_antihermitian(A, alg)
36-
@test isantihermitian(Ba)
37-
@test Ba Aa
38-
@test A == Ac
39-
Ba_approx = Ba + noisefactor * Ah
40-
@test !isantihermitian(Ba_approx)
41-
# this is never anti-hermitian for real Diagonal: |A - A'| == 0
42-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0
43-
44-
Bh = project_hermitian!(Ac, alg)
45-
@test Bh === Ac
46-
@test ishermitian(Bh)
47-
@test Bh Ah
48-
49-
copy!(Ac, A)
50-
Ba = project_antihermitian!(Ac, alg)
51-
@test Ba === Ac
52-
@test isantihermitian(Ba)
53-
@test Ba Aa
17+
m = 54
18+
for T in (BLASFloats..., GenericFloats...)
19+
TestSuite.seed_rng!(123)
20+
if T BLASFloats
21+
if CUDA.functional()
22+
TestSuite.test_projections(CuMatrix{T}, (m, m); test_blocksize = false)
23+
TestSuite.test_projections(Diagonal{T, CuVector{T}}, m; test_blocksize = false)
5424
end
55-
end
56-
57-
# test approximate error calculation
58-
A = normalize!(randn(rng, T, m, m))
59-
Ah = project_hermitian(A)
60-
Aa = project_antihermitian(A)
61-
62-
Ah_approx = Ah + noisefactor * Aa
63-
ϵ = norm(project_antihermitian(Ah_approx))
64-
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
65-
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)
66-
67-
Aa_approx = Aa + noisefactor * Ah
68-
ϵ = norm(project_hermitian(Aa_approx))
69-
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
70-
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
71-
end
72-
73-
@testset "project_isometric! for T = $T" for T in BLASFloats
74-
rng = StableRNG(123)
75-
m = 54
76-
@testset "size ($m, $n)" for n in (37, m)
77-
k = min(m, n)
78-
if LinearAlgebra.LAPACK.version() < v"3.12.0"
79-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
80-
else
81-
svdalgs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi())
82-
end
83-
algs = (PolarViaSVD.(svdalgs)..., PolarNewton())
84-
@testset "algorithm $alg" for alg in algs
85-
A = randn(rng, T, m, n)
86-
W = project_isometric(A, alg)
87-
@test isisometric(W)
88-
W2 = project_isometric(W, alg)
89-
@test W2 W # stability of the projection
90-
@test W * (W' * A) ≈ A
91-
92-
Ac = similar(A)
93-
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
94-
@test W2 === W
95-
@test isisometric(W)
96-
97-
# test that W is closer to A then any other isometry
98-
for k in 1:10
99-
δA = randn(rng, T, m, n)
100-
W = project_isometric(A, alg)
101-
W2 = project_isometric(A + δA / 100, alg)
102-
@test norm(A - W2) > norm(A - W)
103-
end
25+
if AMDGPU.functional()
26+
TestSuite.test_projections(ROCMatrix{T}, (m, m); test_blocksize = false)
27+
TestSuite.test_projections(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
10428
end
10529
end
30+
if !is_buildkite
31+
TestSuite.test_projections(T, (m, m))
32+
AT = Diagonal{T, Vector{T}}
33+
TestSuite.test_projections(AT, m)
34+
end
10635
end

test/runtests.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ if !is_buildkite
77
@safetestset "Algorithms" begin
88
include("algorithms.jl")
99
end
10-
@safetestset "Projections" begin
11-
include("projections.jl")
12-
end
1310
@safetestset "Truncate" begin
1411
include("truncate.jl")
1512
end
@@ -71,12 +68,12 @@ end
7168
include("qr.jl")
7269
include("lq.jl")
7370
end
71+
@safetestset "Projections" begin
72+
include("projections.jl")
73+
end
7474

7575
using CUDA
7676
if CUDA.functional()
77-
@safetestset "CUDA Projections" begin
78-
include("cuda/projections.jl")
79-
end
8077
@safetestset "CUDA SVD" begin
8178
include("cuda/svd.jl")
8279
end
@@ -96,9 +93,6 @@ end
9693

9794
using AMDGPU
9895
if AMDGPU.functional()
99-
@safetestset "AMDGPU Projections" begin
100-
include("amd/projections.jl")
101-
end
10296
@safetestset "AMDGPU SVD" begin
10397
include("amd/svd.jl")
10498
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,6 @@ is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg)
7171

7272
include("qr.jl")
7373
include("lq.jl")
74+
include("projections.jl")
7475

7576
end

0 commit comments

Comments
 (0)