Skip to content

Commit 4f5bcb1

Browse files
authored
Test truncated methods with GPU arrays (#142)
* Test truncated methods with GPU arrays
1 parent 3326e00 commit 4f5bcb1

File tree

6 files changed

+80
-72
lines changed

6 files changed

+80
-72
lines changed

src/implementations/eigh.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
4343
@assert isdiag(A)
4444
m = size(A, 1)
4545
D, V = DV
46-
@assert D isa Diagonal && V isa Diagonal
46+
@assert D isa Diagonal
4747
@check_size(D, (m, m))
4848
@check_scalar(D, A, real)
4949
@check_size(V, (m, m))
@@ -79,7 +79,7 @@ function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_err
7979
end
8080

8181
function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
82-
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
82+
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A, size(A)...)
8383
end
8484
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
8585
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
@@ -146,15 +146,29 @@ end
146146
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
147147
check_input(eigh_full!, A, DV, alg)
148148
D, V = DV
149-
D === A || (diagview(D) .= real.(diagview(A)))
150-
one!(V)
149+
diagA = diagview(A)
150+
I = sortperm(diagA; by = real)
151+
if D === A
152+
permute!(diagA, I)
153+
else
154+
diagview(D) .= real.(view(diagA, I))
155+
end
156+
zero!(V)
157+
n = size(A, 1)
158+
I .+= (0:(n - 1)) .* n
159+
V[I] .= Ref(one(eltype(V)))
151160
return D, V
152161
end
153162

154163
function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
155164
check_input(eigh_vals!, A, D, alg)
156165
Ad = diagview(A)
157-
D === Ad || (D .= real.(Ad))
166+
if D === Ad
167+
sort!(Ad)
168+
else
169+
D .= real.(Ad)
170+
sort!(D)
171+
end
158172
return D
159173
end
160174

src/implementations/truncation.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
4949

5050
function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
5151
howmany = min(strategy.howmany, length(values))
52-
return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
52+
return sortperm(values; strategy.by, strategy.rev)[1:howmany]
5353
end
5454
function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder)
5555
strategy.by === abs || return findtruncated(values, strategy)
@@ -96,14 +96,8 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real =
9696
# fast path to avoid checking all values
9797
ϵᵖ Nᵖ && return Base.OneTo(0)
9898

99-
truncerrᵖ = zero(real(eltype(values)))
100-
rank = length(values)
101-
for i in reverse(I)
102-
truncerrᵖ += by(values[i])
103-
truncerrᵖ ϵᵖ && break
104-
rank -= 1
105-
end
106-
99+
truncerrᵖ_array = cumsum(map(by, view(values, reverse(I))))
100+
rank = length(values) - (findfirst((ϵᵖ), truncerrᵖ_array) - 1)
107101
return Base.OneTo(rank)
108102
end
109103

test/eig.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ for T in (BLASFloats..., GenericFloats...)
1919
TestSuite.seed_rng!(123)
2020
if T BLASFloats
2121
if CUDA.functional()
22-
TestSuite.test_eig(CuMatrix{T}, (m, m); test_trunc = false)
23-
TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),); test_trunc = false)
24-
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_trunc = false)
25-
TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
22+
TestSuite.test_eig(CuMatrix{T}, (m, m))
23+
TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),))
24+
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m)
25+
TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
2626
end
2727
#= not yet supported
2828
if AMDGPU.functional()

test/eigh.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ for T in (BLASFloats..., GenericFloats...)
2222
CUSOLVER_Jacobi(),
2323
CUSOLVER_DivideAndConquer(),
2424
)
25-
TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false)
26-
TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false)
27-
TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_trunc = false)
28-
TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
25+
TestSuite.test_eigh(CuMatrix{T}, (m, m))
26+
TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS)
27+
TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m)
28+
TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
2929
end
3030
if AMDGPU.functional()
3131
ROCSOLVER_EIGH_ALGS = (
@@ -34,6 +34,7 @@ for T in (BLASFloats..., GenericFloats...)
3434
ROCSOLVER_QRIteration(),
3535
ROCSOLVER_Bisection(),
3636
)
37+
# see https://github.com/JuliaGPU/AMDGPU.jl/issues/837
3738
TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false)
3839
TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false)
3940
TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false)

test/testsuite/eig.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@ using MatrixAlgebraKit: TruncatedAlgorithm
33
using LinearAlgebra: I
44
using GenericSchur
55

6-
function test_eig(T::Type, sz; test_trunc = true, kwargs...)
6+
function test_eig(T::Type, sz; kwargs...)
77
summary_str = testargs_summary(T, sz)
88
return @testset "eig $summary_str" begin
99
test_eig_full(T, sz; kwargs...)
10-
test_trunc && test_eig_trunc(T, sz; kwargs...)
10+
test_eig_trunc(T, sz; kwargs...)
1111
end
1212
end
1313

14-
function test_eig_algs(T::Type, sz, algs; test_trunc = true, kwargs...)
14+
function test_eig_algs(T::Type, sz, algs; kwargs...)
1515
summary_str = testargs_summary(T, sz)
1616
return @testset "eig algorithms $summary_str" begin
1717
test_eig_full_algs(T, sz, algs; kwargs...)
18-
test_trunc && test_eig_trunc_algs(T, sz, algs; kwargs...)
18+
test_eig_trunc_algs(T, sz, algs; kwargs...)
1919
end
2020
end
2121

@@ -78,7 +78,7 @@ function test_eig_trunc(
7878
Ac = deepcopy(A)
7979
Tc = complex(eltype(T))
8080
# eigenvalues are sorted by ascending real component...
81-
D₀ = sort!(eig_vals(A); by = abs, rev = true)
81+
D₀ = collect(sort!(eig_vals(A); by = abs, rev = true))
8282
m = size(A, 1)
8383
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
8484
r = length(D₀) - rmin
@@ -150,7 +150,7 @@ function test_eig_trunc_algs(
150150
Ac = deepcopy(A)
151151
Tc = complex(eltype(T))
152152
# eigenvalues are sorted by ascending real component...
153-
D₀ = sort!(eig_vals(A; alg); by = abs, rev = true)
153+
D₀ = collect(sort!(eig_vals(A; alg); by = abs, rev = true))
154154
m = size(A, 1)
155155
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
156156
r = length(D₀) - rmin

test/testsuite/eigh.jl

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -76,50 +76,49 @@ function test_eigh_trunc(
7676
A = A * A'
7777
A = project_hermitian!(A)
7878
Ac = deepcopy(A)
79-
if !(T <: Diagonal)
8079
81-
m = size(A, 1)
82-
D₀ = reverse(eigh_vals(A))
83-
r = m - 2
84-
s = 1 + sqrt(eps(real(eltype(T))))
85-
atol = sqrt(eps(real(eltype(T))))
86-
# truncrank
87-
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
88-
@test length(diagview(D1)) == r
89-
@test isisometric(V1)
90-
@test A * V1 ≈ V1 * D1
91-
@test opnorm(A - V1 * D1 * V1') D₀[r + 1]
92-
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
93-
94-
# trunctol
95-
trunc = trunctol(; atol = s * D₀[r + 1])
96-
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
97-
@test length(diagview(D2)) == r
98-
@test isisometric(V2)
99-
@test A * V2 V2 * D2
100-
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
101-
102-
#truncerror
103-
s = 1 - sqrt(eps(real(eltype(T))))
104-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
105-
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
106-
@test length(diagview(D3)) == r
107-
@test A * V3 V3 * D3
108-
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
109-
110-
s = 1 - sqrt(eps(real(eltype(T))))
111-
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
112-
D4, V4 = @testinferred eigh_trunc_no_error(A; trunc)
113-
@test length(diagview(D4)) == r
114-
@test A * V4 V4 * D4
115-
116-
# test for same subspace
117-
@test V1 * (V1' * V2) ≈ V2
118-
@test V2 * (V2' * V1) V1
119-
@test V1 * (V1' * V3) ≈ V3
120-
@test V3 * (V3' * V1) V1
121-
@test V4 * (V4' * V1) ≈ V1
122-
end
80+
m = size(A, 1)
81+
D₀ = collect(reverse(eigh_vals(A)))
82+
r = m - 2
83+
s = 1 + sqrt(eps(real(eltype(T))))
84+
atol = sqrt(eps(real(eltype(T))))
85+
# truncrank
86+
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
87+
@test length(diagview(D1)) == r
88+
@test isisometric(V1)
89+
@test A * V1 ≈ V1 * D1
90+
@test opnorm(A - V1 * D1 * V1') D₀[r + 1]
91+
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
92+
93+
# trunctol
94+
trunc = trunctol(; atol = s * D₀[r + 1])
95+
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
96+
@test length(diagview(D2)) == r
97+
@test isisometric(V2)
98+
@test A * V2 V2 * D2
99+
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
100+
101+
#truncerror
102+
s = 1 - sqrt(eps(real(eltype(T))))
103+
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
104+
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
105+
@test length(diagview(D3)) == r
106+
@test A * V3 V3 * D3
107+
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
108+
109+
s = 1 - sqrt(eps(real(eltype(T))))
110+
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
111+
D4, V4 = @testinferred eigh_trunc_no_error(A; trunc)
112+
@test length(diagview(D4)) == r
113+
@test A * V4 V4 * D4
114+
115+
# test for same subspace
116+
@test V1 * (V1' * V2) ≈ V2
117+
@test V2 * (V2' * V1) V1
118+
@test V1 * (V1' * V3) ≈ V3
119+
@test V3 * (V3' * V1) V1
120+
@test V4 * (V4' * V1) ≈ V1
121+
123122
@testset "specify truncation algorithm" begin
124123
atol = sqrt(eps(real(eltype(T))))
125124
m4 = 4
@@ -156,7 +155,7 @@ function test_eigh_trunc_algs(
156155
Ac = deepcopy(A)
157156
158157
m = size(A, 1)
159-
D₀ = reverse(eigh_vals(A))
158+
D₀ = collect(reverse(eigh_vals(A)))
160159
r = m - 2
161160
s = 1 + sqrt(eps(real(eltype(T))))
162161
# truncrank

0 commit comments

Comments
 (0)