diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 331b691f..2e295231 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Fixed +- Eigenvalue decompositions of diagonal inputs are sorted and have the same type as non-diagonal inputs ([#151](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/151) + ## [0.6.2](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.1...v0.6.2) - 2026-01-08 ### Added diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 5a0dd679..9b785898 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -30,23 +30,21 @@ end function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm) m, n = size(A) - @assert m == n && isdiag(A) + ((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected")) D, V = DV - @assert D isa Diagonal && V isa Diagonal + @assert D isa Diagonal && V isa AbstractMatrix @check_size(D, (m, m)) + @check_scalar(D, A, complex) @check_size(V, (m, m)) - # Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable - @check_scalar(D, A) - @check_scalar(V, A) + @check_scalar(V, A, complex) return nothing end function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm) m, n = size(A) - @assert m == n && isdiag(A) + ((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected")) @assert D isa AbstractVector @check_size(D, (n,)) - # Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable - @check_scalar(D, A) + @check_scalar(D, A, complex) return nothing end @@ -70,10 +68,14 @@ function initialize_output(::Union{typeof(eig_trunc!), typeof(eig_trunc_no_error end function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm) - return A, similar(A) + T = eltype(A) + Tc = complex(T) + D = T <: Complex ? A : Diagonal(similar(A, Tc, size(A, 1))) + return D, similar(A, Tc, size(A)) end function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm) - return diagview(A) + T = eltype(A) + return T <: Complex ? diagview(A) : similar(A, complex(T), size(A, 1)) end # Implementation @@ -129,10 +131,21 @@ end # Diagonal logic # -------------- -function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithm) - check_input(eig_full!, A, (D, V), alg) - D === A || copy!(D, A) - one!(V) +eig_sortby(x::T) where {T <: Number} = T <: Complex ? (real(x), imag(x)) : x +function eig_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) + check_input(eig_full!, A, DV, alg) + D, V = DV + diagA = diagview(A) + I = sortperm(diagA; by = eig_sortby) + if D === A + permute!(diagA, I) + else + diagview(D) .= view(diagA, I) + end + zero!(V) + n = size(A, 1) + I .+= (0:(n - 1)) .* n + V[I] .= Ref(one(eltype(V))) return D, V end @@ -140,6 +153,7 @@ function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm) check_input(eig_vals!, A, D, alg) Ad = diagview(A) D === Ad || copy!(D, Ad) + sort!(D; by = eig_sortby) return D end diff --git a/test/testsuite/eig.jl b/test/testsuite/eig.jl index 61ed1fc8..6ddfef5d 100644 --- a/test/testsuite/eig.jl +++ b/test/testsuite/eig.jl @@ -28,7 +28,7 @@ function test_eig_full( return @testset "eig_full! $summary_str" begin A = instantiate_matrix(T, sz) Ac = deepcopy(A) - Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T)) + Tc = complex(eltype(T)) D, V = @testinferred eig_full(A) @test eltype(D) == eltype(V) == Tc @test A * V ≈ V * D @@ -51,7 +51,7 @@ function test_eig_full_algs( return @testset "eig_full! algorithm $alg $summary_str" for alg in algs A = instantiate_matrix(T, sz) Ac = deepcopy(A) - Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T)) + Tc = complex(eltype(T)) D, V = @testinferred eig_full(A; alg) @test eltype(D) == eltype(V) == Tc @test A * V ≈ V * D