diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl
index 40f2c557..19a190b1 100644
--- a/src/implementations/eigh.jl
+++ b/src/implementations/eigh.jl
@@ -43,7 +43,7 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalA
     @assert isdiag(A)
     m = size(A, 1)
     D, V = DV
-    @assert D isa Diagonal && V isa Diagonal
+    @assert D isa Diagonal
     @check_size(D, (m, m))
     @check_scalar(D, A, real)
     @check_size(V, (m, m))
@@ -79,7 +79,7 @@ function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_err
 end
 
 function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
-    return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A)
+    return eltype(A) <: Real ? A : similar(A, real(eltype(A))), similar(A, size(A)...)
 end
 function initialize_output(::typeof(eigh_vals!), A::Diagonal, ::DiagonalAlgorithm)
     return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1))
@@ -146,15 +146,29 @@ end
 function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
     check_input(eigh_full!, A, DV, alg)
     D, V = DV
-    D === A || (diagview(D) .= real.(diagview(A)))
-    one!(V)
+    diagA = diagview(A)
+    I = sortperm(diagA; by = real)
+    if D === A
+        permute!(diagA, I)
+    else
+        diagview(D) .= real.(view(diagA, I))
+    end
+    zero!(V)
+    n = size(A, 1)
+    I .+= (0:(n - 1)) .* n
+    V[I] .= Ref(one(eltype(V)))
     return D, V
 end
 
 function eigh_vals!(A::Diagonal, D, alg::DiagonalAlgorithm)
     check_input(eigh_vals!, A, D, alg)
     Ad = diagview(A)
-    D === Ad || (D .= real.(Ad))
+    if D === Ad
+        sort!(Ad)
+    else
+        D .= real.(Ad)
+        sort!(D)
+    end
     return D
 end
 
diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl
index 883c7759..be730bed 100644
--- a/src/implementations/truncation.jl
+++ b/src/implementations/truncation.jl
@@ -49,7 +49,7 @@ findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
 
 function findtruncated(values::AbstractVector, strategy::TruncationByOrder)
     howmany = min(strategy.howmany, length(values))
-    return partialsortperm(values, 1:howmany; strategy.by, strategy.rev)
+    return sortperm(values; strategy.by, strategy.rev)[1:howmany]
 end
 function findtruncated_svd(values::AbstractVector, strategy::TruncationByOrder)
     strategy.by === abs || return findtruncated(values, strategy)
@@ -96,14 +96,12 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real =
     # fast path to avoid checking all values
     ϵᵖ ≥ Nᵖ && return Base.OneTo(0)
 
-    truncerrᵖ = zero(real(eltype(values)))
-    rank = length(values)
-    for i in reverse(I)
-        truncerrᵖ += by(values[i])
-        truncerrᵖ ≥ ϵᵖ && break
-        rank -= 1
-    end
-
+    truncerrᵖ_array = cumsum(map(by, view(values, reverse(I))))
+    # `findfirst` returns `nothing` if no partial sum reaches ϵᵖ; in that case drop
+    # every index in `I`, which is what the replaced loop did when it never hit `break`
+    k = findfirst(≥(ϵᵖ), truncerrᵖ_array)
+    ndrop = k === nothing ? length(truncerrᵖ_array) : k - 1
+    rank = length(values) - ndrop
     return Base.OneTo(rank)
 end
 
diff --git a/test/eig.jl b/test/eig.jl
index 7cc54c5d..df6fdf86 100644
--- a/test/eig.jl
+++ b/test/eig.jl
@@ -19,10 +19,10 @@ for T in (BLASFloats..., GenericFloats...)
     TestSuite.seed_rng!(123)
     if T ∈ BLASFloats
         if CUDA.functional()
-            TestSuite.test_eig(CuMatrix{T}, (m, m); test_trunc = false)
-            TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),); test_trunc = false)
-            TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_trunc = false)
-            TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
+            TestSuite.test_eig(CuMatrix{T}, (m, m))
+            TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),))
+            TestSuite.test_eig(Diagonal{T, CuVector{T}}, m)
+            TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
         end
         #= not yet supported
         if AMDGPU.functional()
diff --git a/test/eigh.jl b/test/eigh.jl
index 8766ccc0..2efb4e15 100644
--- a/test/eigh.jl
+++ b/test/eigh.jl
@@ -22,10 +22,10 @@ for T in (BLASFloats..., GenericFloats...)
                 CUSOLVER_Jacobi(),
                 CUSOLVER_DivideAndConquer(),
             )
-            TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false)
-            TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false)
-            TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_trunc = false)
-            TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
+            TestSuite.test_eigh(CuMatrix{T}, (m, m))
+            TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS)
+            TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m)
+            TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),))
         end
         if AMDGPU.functional()
             ROCSOLVER_EIGH_ALGS = (
@@ -34,6 +34,7 @@ for T in (BLASFloats..., GenericFloats...)
                 ROCSOLVER_QRIteration(),
                 ROCSOLVER_Bisection(),
             )
+            # see https://github.com/JuliaGPU/AMDGPU.jl/issues/837
             TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false)
             TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false)
             TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false)
diff --git a/test/testsuite/eig.jl b/test/testsuite/eig.jl
index 2dbea8b9..61ed1fc8 100644
--- a/test/testsuite/eig.jl
+++ b/test/testsuite/eig.jl
@@ -3,19 +3,19 @@ using MatrixAlgebraKit: TruncatedAlgorithm
 using LinearAlgebra: I
 using GenericSchur
 
-function test_eig(T::Type, sz; test_trunc = true, kwargs...)
+function test_eig(T::Type, sz; kwargs...)
     summary_str = testargs_summary(T, sz)
     return @testset "eig $summary_str" begin
         test_eig_full(T, sz; kwargs...)
-        test_trunc && test_eig_trunc(T, sz; kwargs...)
+        test_eig_trunc(T, sz; kwargs...)
     end
 end
 
-function test_eig_algs(T::Type, sz, algs; test_trunc = true, kwargs...)
+function test_eig_algs(T::Type, sz, algs; kwargs...)
     summary_str = testargs_summary(T, sz)
     return @testset "eig algorithms $summary_str" begin
         test_eig_full_algs(T, sz, algs; kwargs...)
-        test_trunc && test_eig_trunc_algs(T, sz, algs; kwargs...)
+        test_eig_trunc_algs(T, sz, algs; kwargs...)
     end
 end
 
@@ -78,7 +78,7 @@ function test_eig_trunc(
         Ac = deepcopy(A)
         Tc = complex(eltype(T))
         # eigenvalues are sorted by ascending real component...
-        D₀ = sort!(eig_vals(A); by = abs, rev = true)
+        D₀ = collect(sort!(eig_vals(A); by = abs, rev = true))
         m = size(A, 1)
         rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
         r = length(D₀) - rmin
@@ -150,7 +150,7 @@ function test_eig_trunc_algs(
         Ac = deepcopy(A)
         Tc = complex(eltype(T))
         # eigenvalues are sorted by ascending real component...
-        D₀ = sort!(eig_vals(A; alg); by = abs, rev = true)
+        D₀ = collect(sort!(eig_vals(A; alg); by = abs, rev = true))
         m = size(A, 1)
         rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
         r = length(D₀) - rmin
diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl
index df6e4d6e..087ee734 100644
--- a/test/testsuite/eigh.jl
+++ b/test/testsuite/eigh.jl
@@ -76,50 +76,49 @@ function test_eigh_trunc(
         A = A * A'
         A = project_hermitian!(A)
         Ac = deepcopy(A)
-        if !(T <: Diagonal)
-            m = size(A, 1)
-            D₀ = reverse(eigh_vals(A))
-            r = m - 2
-            s = 1 + sqrt(eps(real(eltype(T))))
-            atol = sqrt(eps(real(eltype(T))))
-            # truncrank
-            D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
-            @test length(diagview(D1)) == r
-            @test isisometric(V1)
-            @test A * V1 ≈ V1 * D1
-            @test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
-            @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol
-
-            # trunctol
-            trunc = trunctol(; atol = s * D₀[r + 1])
-            D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
-            @test length(diagview(D2)) == r
-            @test isisometric(V2)
-            @test A * V2 ≈ V2 * D2
-            @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol
-
-            #truncerror
-            s = 1 - sqrt(eps(real(eltype(T))))
-            trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
-            D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
-            @test length(diagview(D3)) == r
-            @test A * V3 ≈ V3 * D3
-            @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol
-
-            s = 1 - sqrt(eps(real(eltype(T))))
-            trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
-            D4, V4 = @testinferred eigh_trunc_no_error(A; trunc)
-            @test length(diagview(D4)) == r
-            @test A * V4 ≈ V4 * D4
-
-            # test for same subspace
-            @test V1 * (V1' * V2) ≈ V2
-            @test V2 * (V2' * V1) ≈ V1
-            @test V1 * (V1' * V3) ≈ V3
-            @test V3 * (V3' * V1) ≈ V1
-            @test V4 * (V4' * V1) ≈ V1
-        end
 
+        m = size(A, 1)
+        D₀ = collect(reverse(eigh_vals(A)))
+        r = m - 2
+        s = 1 + sqrt(eps(real(eltype(T))))
+        atol = sqrt(eps(real(eltype(T))))
+        # truncrank
+        D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r))
+        @test length(diagview(D1)) == r
+        @test isisometric(V1)
+        @test A * V1 ≈ V1 * D1
+        @test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
+        @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol
+
+        # trunctol
+        trunc = trunctol(; atol = s * D₀[r + 1])
+        D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc)
+        @test length(diagview(D2)) == r
+        @test isisometric(V2)
+        @test A * V2 ≈ V2 * D2
+        @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol
+
+        #truncerror
+        s = 1 - sqrt(eps(real(eltype(T))))
+        trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
+        D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc)
+        @test length(diagview(D3)) == r
+        @test A * V3 ≈ V3 * D3
+        @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol
+
+        s = 1 - sqrt(eps(real(eltype(T))))
+        trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
+        D4, V4 = @testinferred eigh_trunc_no_error(A; trunc)
+        @test length(diagview(D4)) == r
+        @test A * V4 ≈ V4 * D4
+
+        # test for same subspace
+        @test V1 * (V1' * V2) ≈ V2
+        @test V2 * (V2' * V1) ≈ V1
+        @test V1 * (V1' * V3) ≈ V3
+        @test V3 * (V3' * V1) ≈ V1
+        @test V4 * (V4' * V1) ≈ V1
+
         @testset "specify truncation algorithm" begin
             atol = sqrt(eps(real(eltype(T))))
             m4 = 4
@@ -156,7 +155,7 @@ function test_eigh_trunc_algs(
         Ac = deepcopy(A)
 
         m = size(A, 1)
-        D₀ = reverse(eigh_vals(A))
+        D₀ = collect(reverse(eigh_vals(A)))
         r = m - 2
         s = 1 + sqrt(eps(real(eltype(T))))
         # truncrank