Skip to content

Commit 367f329

Browse files
committed
standardize eig output types
1 parent d4af9fa commit 367f329

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

src/implementations/eig.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,21 @@ end
3030

3131
function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
3232
m, n = size(A)
33-
@assert m == n && isdiag(A)
33+
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
3434
D, V = DV
35-
@assert D isa Diagonal
35+
@assert D isa Diagonal && V isa AbstractMatrix
3636
@check_size(D, (m, m))
37+
@check_scalar(D, A, complex)
3738
@check_size(V, (m, m))
38-
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
39-
@check_scalar(D, A)
40-
@check_scalar(V, A, real)
39+
@check_scalar(V, A, complex)
4140
return nothing
4241
end
4342
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
4443
m, n = size(A)
45-
@assert m == n && isdiag(A)
44+
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
4645
@assert D isa AbstractVector
4746
@check_size(D, (n,))
48-
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
49-
@check_scalar(D, A)
47+
@check_scalar(D, A, complex)
5048
return nothing
5149
end
5250

@@ -70,10 +68,14 @@ function initialize_output(::Union{typeof(eig_trunc!), typeof(eig_trunc_no_error
7068
end
7169

7270
function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
73-
return A, similar(A, real(eltype(A)), size(A))
71+
T = eltype(A)
72+
Tc = complex(T)
73+
D = T <: Complex ? A : Diagonal(similar(A, Tc, size(A, 1)))
74+
return D, similar(A, Tc, size(A))
7475
end
7576
function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm)
76-
return diagview(A)
77+
T = eltype(A)
78+
return T <: Complex ? diagview(A) : similar(A, complex(T), size(A, 1))
7779
end
7880

7981
# Implementation

test/testsuite/eig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function test_eig_full(
2828
return @testset "eig_full! $summary_str" begin
2929
A = instantiate_matrix(T, sz)
3030
Ac = deepcopy(A)
31-
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
31+
Tc = complex(eltype(T))
3232
D, V = @testinferred eig_full(A)
3333
@test eltype(D) == eltype(V) == Tc
3434
@test A * V V * D
@@ -51,7 +51,7 @@ function test_eig_full_algs(
5151
return @testset "eig_full! algorithm $alg $summary_str" for alg in algs
5252
A = instantiate_matrix(T, sz)
5353
Ac = deepcopy(A)
54-
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
54+
Tc = complex(eltype(T))
5555
D, V = @testinferred eig_full(A; alg)
5656
@test eltype(D) == eltype(V) == Tc
5757
@test A * V V * D

0 commit comments

Comments
 (0)