@@ -8,109 +8,29 @@ using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
88BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
99GenericFloats = (Float16, BigFloat, Complex{BigFloat})
1010
11- @testset " eig_full! for T = $T " for T in BLASFloats
12- rng = StableRNG(123 )
13- m = 54
14- for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple)
15- A = randn(rng, T, m, m)
16- Tc = complex(T)
11+ using CUDA, AMDGPU
1712
18- D, V = @constinferred eig_full(A; alg = ($ alg))
19- @test eltype(D) == eltype(V) == Tc
20- @test A * V ≈ V * D
21-
22- alg′ = @constinferred MatrixAlgebraKit. select_algorithm(eig_full!, A, $ alg)
23-
24- Ac = similar(A)
25- D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
26- @test D2 === D
27- @test V2 === V
28- @test A * V ≈ V * D
29-
30- Dc = @constinferred eig_vals(A, alg′)
31- @test eltype(Dc) == Tc
32- @test D ≈ Diagonal(Dc)
33- end
34- end
35-
36- @testset " eig_trunc! for T = $T " for T in BLASFloats
37- rng = StableRNG(123 )
38- m = 54
39- for alg in (LAPACK_Simple(), LAPACK_Expert())
40- A = randn(rng, T, m, m)
41- A *= A' # TODO: deal with eigenvalue ordering etc
42- # eigenvalues are sorted by ascending real component...
43- D₀ = sort!(eig_vals(A); by = abs, rev = true)
44- rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
45- r = length(D₀) - rmin
46- atol = sqrt(eps(real(T)))
47-
48- D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
49- @test length(diagview(D1)) == r
50- @test A * V1 ≈ V1 * D1
51- @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol
52-
53- s = 1 + sqrt(eps(real(T)))
54- trunc = trunctol(; atol = s * abs(D₀[r + 1]))
55- D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
56- @test length(diagview(D2)) == r
57- @test A * V2 ≈ V2 * D2
58- @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol
59-
60- s = 1 - sqrt(eps(real(T)))
61- trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
62- D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
63- @test length(diagview(D3)) == r
64- @test A * V3 ≈ V3 * D3
65- @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol
66-
67- # trunctol keeps order, truncrank might not
68- # test for same subspace
69- @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
70- @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
71- @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3
72- @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1
13+ BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
14+ GenericFloats = (Float16,) # BigFloat, Complex{BigFloat})
15+
16+ @isdefined(TestSuite) || include(" testsuite/TestSuite.jl" )
17+ using . TestSuite
18+
19+ m = 54
20+ for T in BLASFloats
21+ TestSuite. seed_rng!(123 )
22+ TestSuite. test_eig(T, (m, m))
23+ if CUDA. functional()
24+ TestSuite. test_eig(CuMatrix{T}, (m, m); test_blocksize = false )
25+ TestSuite. test_eig(Diagonal{T, CuVector{T}}, m; test_blocksize = false )
7326 end
27+ #= not yet supported
28+ if AMDGPU.functional()
29+ TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false)
30+ TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
31+ end=#
7432end
75-
76- @testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats
77- rng = StableRNG(123)
78- m = 4
79- atol = sqrt(eps(real(T)))
80- V = randn(rng, T, m, m)
81- D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
82- A = V * D * inv(V)
83- alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
84- D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
85- @test diagview(D2) ≈ diagview(D)[1:2]
86- @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
87- @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))
88-
89- alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
90- D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
91- @test diagview(D3) ≈ diagview(D)[1:2]
92- @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
93- end
94-
95- @testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
96- rng = StableRNG(123)
97- m = 54
98- Ad = randn(rng, T, m)
99- A = Diagonal(Ad)
100- atol = sqrt(eps(real(T)))
101-
102- D, V = @constinferred eig_full(A)
103- @test D isa Diagonal{T} && size(D) == size(A)
104- @test V isa Diagonal{T} && size(V) == size(A)
105- @test A * V ≈ V * D
106-
107- D2 = @constinferred eig_vals(A)
108- @test D2 isa AbstractVector{T} && length(D2) == m
109- @test diagview(D) ≈ D2
110-
111- A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
112- alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
113- D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
114- @test diagview(D2) ≈ diagview(A2)[1:2]
115- @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol
33+ for T in (BLASFloats... , GenericFloats... )
34+ AT = Diagonal{T, Vector{T}}
35+ TestSuite. test_eig(AT, m; test_blocksize = false )
11636end
0 commit comments