@@ -3,124 +3,28 @@ using Test
33using TestExtras
44using StableRNGs
55using LinearAlgebra: LinearAlgebra, Diagonal, I
6- using MatrixAlgebraKit : TruncatedAlgorithm, diagview, norm
6+ using CUDA, AMDGPU
77
88BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1010
11- @testset " eigh_full! for T = $T " for T in BLASFloats
12- rng = StableRNG(123 )
13- m = 54
14- for alg in (
15- LAPACK_MultipleRelativelyRobustRepresentations(),
16- LAPACK_DivideAndConquer(),
17- LAPACK_QRIteration(),
18- LAPACK_Bisection(),
19- )
20- A = randn(rng, T, m, m)
21- A = (A + A' ) / 2
11+ @isdefined(TestSuite) || include(" testsuite/TestSuite.jl" )
12+ using . TestSuite
2213
23- D, V = @constinferred eigh_full(A; alg)
24- @test A * V ≈ V * D
25- @test isunitary(V)
26- @test all(isreal, D)
27-
28- D2, V2 = eigh_full!(copy(A), (D, V), alg)
29- @test D2 === D
30- @test V2 === V
31-
32- D3 = @constinferred eigh_vals(A, alg)
33- @test D ≈ Diagonal(D3)
14+ m = 54
15+ for T in BLASFloats
16+ TestSuite. seed_rng!(123 )
17+ TestSuite. test_eigh(T, (m, m))
18+ if CUDA. functional()
19+ TestSuite. test_eigh(CuMatrix{T}, (m, m); test_blocksize = false )
20+ TestSuite. test_eigh(Diagonal{T, CuVector{T}}, m; test_blocksize = false )
3421 end
35- end
36-
37- @testset "eigh_trunc! for T = $T" for T in BLASFloats
38- rng = StableRNG(123)
39- m = 54
40- for alg in (
41- LAPACK_QRIteration(),
42- LAPACK_Bisection(),
43- LAPACK_DivideAndConquer(),
44- LAPACK_MultipleRelativelyRobustRepresentations(),
45- )
46- A = randn(rng, T, m, m)
47- A = A * A'
48- A = (A + A' ) / 2
49- Ac = similar(A)
50- D₀ = reverse(eigh_vals(A))
51- r = m - 2
52- s = 1 + sqrt(eps(real(T)))
53- atol = sqrt(eps(real(T)))
54-
55- D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r))
56- @test length(diagview(D1)) == r
57- @test isisometric(V1)
58- @test A * V1 ≈ V1 * D1
59- @test LinearAlgebra.opnorm(A - V1 * D1 * V1' ) ≈ D₀[r + 1 ]
60- @test ϵ1 ≈ norm(view(D₀, (r + 1 ): m)) atol = atol
61-
62- trunc = trunctol(; atol = s * D₀[r + 1 ])
63- D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
64- @test length(diagview(D2)) == r
65- @test isisometric(V2)
66- @test A * V2 ≈ V2 * D2
67- @test ϵ2 ≈ norm(view(D₀, (r + 1 ): m)) atol = atol
68-
69- s = 1 - sqrt(eps(real(T)))
70- trunc = truncerror(; atol = s * norm(@view(D₀[r: end ]), 1 ), p = 1 )
71- D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
72- @test length(diagview(D3)) == r
73- @test A * V3 ≈ V3 * D3
74- @test ϵ3 ≈ norm(view(D₀, (r + 1 ): m)) atol = atol
75-
76- # test for same subspace
77- @test V1 * (V1' * V2) ≈ V2
78- @test V2 * (V2' * V1) ≈ V1
79- @test V1 * (V1' * V3) ≈ V3
80- @test V3 * (V3' * V1) ≈ V1
22+ if AMDGPU. functional()
23+ TestSuite. test_eigh(ROCMatrix{T}, (m, m); test_blocksize = false )
24+ TestSuite. test_eigh(Diagonal{T, ROCVector{T}}, m; test_blocksize = false )
8125 end
8226end
83-
84- @testset " eigh_trunc! specify truncation algorithm T = $T " for T in BLASFloats
85- rng = StableRNG(123 )
86- m = 4
87- atol = sqrt(eps(real(T)))
88- V = qr_compact(randn(rng, T, m, m))[1 ]
89- D = Diagonal(real(T)[0.9 , 0.3 , 0.1 , 0.01 ])
90- A = V * D * V'
91- A = (A + A' ) / 2
92- alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2 ))
93- D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
94- @test diagview(D2) ≈ diagview(D)[1 : 2 ]
95- @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2 ))
96- @test ϵ2 ≈ norm(diagview(D)[3 : 4 ]) atol = atol
97-
98- alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2 ))
99- D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
100- @test diagview(D3) ≈ diagview(D)[1 : 2 ]
101- @test ϵ3 ≈ norm(diagview(D)[3 : 4 ]) atol = atol
102- end
103-
104- @testset " eigh for Diagonal{$T }" for T in (BLASFloats... , GenericFloats... )
105- rng = StableRNG(123 )
106- m = 54
107- Ad = randn(rng, T, m)
108- Ad .+ = conj.(Ad)
109- A = Diagonal(Ad)
110- atol = sqrt(eps(real(T)))
111-
112- D, V = @constinferred eigh_full(A)
113- @test D isa Diagonal{real(T)} && size(D) == size(A)
114- @test V isa Diagonal{T} && size(V) == size(A)
115- @test A * V ≈ V * D
116-
117- D2 = @constinferred eigh_vals(A)
118- @test D2 isa AbstractVector{real(T)} && length(D2) == m
119- @test diagview(D) ≈ D2
120-
121- A2 = Diagonal(T[0.9 , 0.3 , 0.1 , 0.01 ])
122- alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2 ))
123- D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
124- @test diagview(D2) ≈ diagview(A2)[1 : 2 ]
125- @test ϵ2 ≈ norm(diagview(A2)[3 : 4 ]) atol = atol
27+ for T in (BLASFloats... , GenericFloats... )
28+ AT = Diagonal{T, Vector{T}}
29+ TestSuite. test_eigh(AT, m; test_blocksize = false )
12630end
0 commit comments