Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
@assert isdiag(A)
m = size(A, 1)
D, V = DV
@assert D isa Diagonal && V isa Diagonal
@assert D isa Diagonal
@check_size(D, (m, m))
@check_scalar(D, A, real)
@check_size(V, (m, m))
Expand Down Expand Up @@ -79,7 +79,7 @@ function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_err
end

function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A, size(A)...)
end
function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
Expand Down Expand Up @@ -146,15 +146,29 @@ end
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
check_input(eigh_full!, A, DV, alg)
D, V = DV
D === A || (diagview(D) .= real.(diagview(A)))
one!(V)
diagA = diagview(A)
I = sortperm(diagA; by = real)
if D === A
permute!(diagA, I)
else
diagview(D) .= real.(view(diagA, I))
end
zero!(V)
n = size(A, 1)
I .+= (0:(n - 1)) .* n
V[I] .= Ref(one(eltype(V)))
return D, V
end

function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
check_input(eigh_vals!, A, D, alg)
Ad = diagview(A)
D === Ad || (D .= real.(Ad))
if D === Ad
sort!(Ad)
else
D .= real.(Ad)
sort!(D)
end
return D
end

Expand Down
12 changes: 3 additions & 9 deletions src/implementations/truncation.jl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would slightly prefer to have these changes as specializations in the GPU extensions, but definitely happy to try them out like this

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ findtruncated(values::AbstractVector, ::NoTruncation) = Colon()

function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
howmany = min(strategy.howmany, length(values))
return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
return sortperm(values; strategy.by, strategy.rev)[1:howmany]
end
function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder)
strategy.by === abs || return findtruncated(values, strategy)
Expand Down Expand Up @@ -96,14 +96,8 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real =
# fast path to avoid checking all values
ϵᵖ ≥ Nᵖ && return Base.OneTo(0)

truncerrᵖ = zero(real(eltype(values)))
rank = length(values)
for i in reverse(I)
truncerrᵖ += by(values[i])
truncerrᵖ ≥ ϵᵖ && break
rank -= 1
end

truncerrᵖ_array = cumsum(map(by, view(values, reverse(I))))
rank = length(values) - (findfirst(≥(ϵᵖ), truncerrᵖ_array) - 1)
return Base.OneTo(rank)
end

Expand Down
8 changes: 4 additions & 4 deletions test/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if T ∈ BLASFloats
if CUDA.functional()
TestSuite.test_eig(CuMatrix{T}, (m, m); test_trunc = false)
TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),); test_trunc = false)
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_trunc = false)
TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
TestSuite.test_eig(CuMatrix{T}, (m, m))
TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),))
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m)
TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
end
#= not yet supported
if AMDGPU.functional()
Expand Down
9 changes: 5 additions & 4 deletions test/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ for T in (BLASFloats..., GenericFloats...)
CUSOLVER_Jacobi(),
CUSOLVER_DivideAndConquer(),
)
TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false)
TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false)
TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_trunc = false)
TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
TestSuite.test_eigh(CuMatrix{T}, (m, m))
TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS)
TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m)
TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
end
if AMDGPU.functional()
ROCSOLVER_EIGH_ALGS = (
Expand All @@ -34,6 +34,7 @@ for T in (BLASFloats..., GenericFloats...)
ROCSOLVER_QRIteration(),
ROCSOLVER_Bisection(),
)
# see https://github.com/JuliaGPU/AMDGPU.jl/issues/837
TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false)
TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false)
TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false)
Expand Down
12 changes: 6 additions & 6 deletions test/testsuite/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ using MatrixAlgebraKit: TruncatedAlgorithm
using LinearAlgebra: I
using GenericSchur

function test_eig(T::Type, sz; test_trunc = true, kwargs...)
function test_eig(T::Type, sz; kwargs...)
summary_str = testargs_summary(T, sz)
return @testset "eig $summary_str" begin
test_eig_full(T, sz; kwargs...)
test_trunc && test_eig_trunc(T, sz; kwargs...)
test_eig_trunc(T, sz; kwargs...)
end
end

function test_eig_algs(T::Type, sz, algs; test_trunc = true, kwargs...)
function test_eig_algs(T::Type, sz, algs; kwargs...)
summary_str = testargs_summary(T, sz)
return @testset "eig algorithms $summary_str" begin
test_eig_full_algs(T, sz, algs; kwargs...)
test_trunc && test_eig_trunc_algs(T, sz, algs; kwargs...)
test_eig_trunc_algs(T, sz, algs; kwargs...)
end
end

Expand Down Expand Up @@ -78,7 +78,7 @@ function test_eig_trunc(
Ac = deepcopy(A)
Tc = complex(eltype(T))
# eigenvalues are sorted by ascending real component...
D₀ = sort!(eig_vals(A); by = abs, rev = true)
D₀ = collect(sort!(eig_vals(A); by = abs, rev = true))
m = size(A, 1)
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
r = length(D₀) - rmin
Expand Down Expand Up @@ -150,7 +150,7 @@ function test_eig_trunc_algs(
Ac = deepcopy(A)
Tc = complex(eltype(T))
# eigenvalues are sorted by ascending real component...
D₀ = sort!(eig_vals(A; alg); by = abs, rev = true)
D₀ = collect(sort!(eig_vals(A; alg); by = abs, rev = true))
m = size(A, 1)
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
r = length(D₀) - rmin
Expand Down
87 changes: 43 additions & 44 deletions test/testsuite/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,50 +76,49 @@ function test_eigh_trunc(
A = A * A'
A = project_hermitian!(A)
Ac = deepcopy(A)
if !(T <: Diagonal)

m = size(A, 1)
D₀ = reverse(eigh_vals(A))
r = m - 2
s = 1 + sqrt(eps(real(eltype(T))))
atol = sqrt(eps(real(eltype(T))))
# truncrank
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
@test length(diagview(D1)) == r
@test isisometric(V1)
@test A * V1 ≈ V1 * D1
@test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol

# trunctol
trunc = trunctol(; atol = s * D₀[r + 1])
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
@test length(diagview(D2)) == r
@test isisometric(V2)
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol

#truncerror
s = 1 - sqrt(eps(real(eltype(T))))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(eltype(T))))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D4, V4 = @testinferred eigh_trunc_no_error(A; trunc)
@test length(diagview(D4)) == r
@test A * V4 ≈ V4 * D4

# test for same subspace
@test V1 * (V1' * V2) ≈ V2
@test V2 * (V2' * V1) ≈ V1
@test V1 * (V1' * V3) ≈ V3
@test V3 * (V3' * V1) ≈ V1
@test V4 * (V4' * V1) ≈ V1
end
m = size(A, 1)
D₀ = collect(reverse(eigh_vals(A)))
r = m - 2
s = 1 + sqrt(eps(real(eltype(T))))
atol = sqrt(eps(real(eltype(T))))
# truncrank
D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
@test length(diagview(D1)) == r
@test isisometric(V1)
@test A * V1 ≈ V1 * D1
@test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol

# trunctol
trunc = trunctol(; atol = s * D₀[r + 1])
D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
@test length(diagview(D2)) == r
@test isisometric(V2)
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol

#truncerror
s = 1 - sqrt(eps(real(eltype(T))))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(eltype(T))))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D4, V4 = @testinferred eigh_trunc_no_error(A; trunc)
@test length(diagview(D4)) == r
@test A * V4 ≈ V4 * D4

# test for same subspace
@test V1 * (V1' * V2) ≈ V2
@test V2 * (V2' * V1) ≈ V1
@test V1 * (V1' * V3) ≈ V3
@test V3 * (V3' * V1) ≈ V1
@test V4 * (V4' * V1) ≈ V1

@testset "specify truncation algorithm" begin
atol = sqrt(eps(real(eltype(T))))
m4 = 4
Expand Down Expand Up @@ -156,7 +155,7 @@ function test_eigh_trunc_algs(
Ac = deepcopy(A)

m = size(A, 1)
D₀ = reverse(eigh_vals(A))
D₀ = collect(reverse(eigh_vals(A)))
r = m - 2
s = 1 + sqrt(eps(real(eltype(T))))
# truncrank
Expand Down