Skip to content

Commit 8ec0063

Browse files
authored
Sort diagonal eig values and standardize output types (#151)
* diagonal `eig` outputs sorted values * correct output types * correct sorting order * Fix typo * standardize `eig` output types * update changelog
1 parent 0f9327b commit 8ec0063

File tree

3 files changed

+32
-16
lines changed

3 files changed

+32
-16
lines changed

docs/src/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
3030

3131
### Fixed
3232

33+
- Eigenvalue decompositions of diagonal inputs are sorted and have the same type as non-diagonal inputs ([#151](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/151)
34+
3335
## [0.6.2](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.1...v0.6.2) - 2026-01-08
3436

3537
### Added

src/implementations/eig.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,21 @@ end
3030

3131
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
3232
m, n = size(A)
33-
@assert m == n && isdiag(A)
33+
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
3434
D, V = DV
35-
@assert D isa Diagonal && V isa Diagonal
35+
@assert D isa Diagonal && V isa AbstractMatrix
3636
@check_size(D, (m, m))
37+
@check_scalar(D, A, complex)
3738
@check_size(V, (m, m))
38-
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
39-
@check_scalar(D, A)
40-
@check_scalar(V, A)
39+
@check_scalar(V, A, complex)
4140
return nothing
4241
end
4342
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
4443
m, n = size(A)
45-
@assert m == n && isdiag(A)
44+
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
4645
@assert D isa AbstractVector
4746
@check_size(D, (n,))
48-
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
49-
@check_scalar(D, A)
47+
@check_scalar(D, A, complex)
5048
return nothing
5149
end
5250

@@ -70,10 +68,14 @@ function initialize_output(::Union{typeof(eig_trunc!), typeof(eig_trunc_no_error
7068
end
7169

7270
function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
73-
return A, similar(A)
71+
T = eltype(A)
72+
Tc = complex(T)
73+
D = T <: Complex ? A : Diagonal(similar(A, Tc, size(A, 1)))
74+
return D, similar(A, Tc, size(A))
7475
end
7576
function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm)
76-
return diagview(A)
77+
T = eltype(A)
78+
return T <: Complex ? diagview(A) : similar(A, complex(T), size(A, 1))
7779
end
7880

7981
# Implementation
@@ -129,17 +131,29 @@ end
129131

130132
# Diagonal logic
131133
# --------------
132-
function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithm)
133-
check_input(eig_full!, A, (D, V), alg)
134-
D === A || copy!(D, A)
135-
one!(V)
134+
eig_sortby(x::T) where {T <: Number} = T <: Complex ? (real(x), imag(x)) : x
135+
function eig_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
136+
check_input(eig_full!, A, DV, alg)
137+
D, V = DV
138+
diagA = diagview(A)
139+
I = sortperm(diagA; by = eig_sortby)
140+
if D === A
141+
permute!(diagA, I)
142+
else
143+
diagview(D) .= view(diagA, I)
144+
end
145+
zero!(V)
146+
n = size(A, 1)
147+
I .+= (0:(n - 1)) .* n
148+
V[I] .= Ref(one(eltype(V)))
136149
return D, V
137150
end
138151

139152
function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm)
140153
check_input(eig_vals!, A, D, alg)
141154
Ad = diagview(A)
142155
D === Ad || copy!(D, Ad)
156+
sort!(D; by = eig_sortby)
143157
return D
144158
end
145159

test/testsuite/eig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function test_eig_full(
2828
return @testset "eig_full! $summary_str" begin
2929
A = instantiate_matrix(T, sz)
3030
Ac = deepcopy(A)
31-
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
31+
Tc = complex(eltype(T))
3232
D, V = @testinferred eig_full(A)
3333
@test eltype(D) == eltype(V) == Tc
3434
@test A * V V * D
@@ -51,7 +51,7 @@ function test_eig_full_algs(
5151
return @testset "eig_full! algorithm $alg $summary_str" for alg in algs
5252
A = instantiate_matrix(T, sz)
5353
Ac = deepcopy(A)
54-
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
54+
Tc = complex(eltype(T))
5555
D, V = @testinferred eig_full(A; alg)
5656
@test eltype(D) == eltype(V) == Tc
5757
@test A * V V * D

0 commit comments

Comments
 (0)