Skip to content

Commit 6daeeb9

Browse files
committed
Add more coverage to general eigenvalue problem
1 parent 20a7736 commit 6daeeb9

File tree

3 files changed

+56
-37
lines changed

3 files changed

+56
-37
lines changed

src/eigenGeneral.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
11
using Printf
22
using LinearAlgebra
3-
using LinearAlgebra: Givens, Rotation
4-
5-
# Auxiliary
6-
function adiagmax(A::StridedMatrix)
7-
adm = zero(typeof(real(A[1])))
8-
@inbounds begin
9-
for i in size(A, 1)
10-
adm = max(adm, abs(A[i, i]))
11-
end
12-
end
13-
return adm
14-
end
3+
using LinearAlgebra: Givens, Rotation, givens
4+
5+
import Base: \
156

167
# Hessenberg Matrix
178
struct HessenbergMatrix{T,S<:StridedMatrix} <: AbstractMatrix{T}
@@ -30,19 +21,23 @@ function LinearAlgebra.ldiv!(H::HessenbergMatrix, B::AbstractVecOrMat)
3021
n = size(H, 1)
3122
Hd = H.data
3223
for i = 1:n-1
33-
G, _ = givens!(Hd, i, i + 1, i)
34-
lmul!(G, view(Hd, 1:n, i+1:n))
24+
G, _ = givens(Hd, i, i + 1, i)
25+
lmul!(G, view(Hd, 1:n, i:n))
3526
lmul!(G, B)
3627
end
37-
ldiv!(Triangular(Hd, :U), B)
28+
ldiv!(UpperTriangular(Hd), B)
3829
end
30+
(\)(H::HessenbergMatrix, B::AbstractVecOrMat) = ldiv!(copy(H), copy(B))
3931

4032
# Hessenberg factorization
4133
struct HessenbergFactorization{T,S<:StridedMatrix,U} <: Factorization{T}
4234
data::S
4335
τ::Vector{U}
4436
end
4537

38+
Base.copy(HF::HessenbergFactorization{T,S,U}) where {T,S,U} =
39+
HessenbergFactorization{T,S,U}(copy(HF.data), copy(HF.τ))
40+
4641
function _hessenberg!(A::StridedMatrix{T}) where {T}
4742
n = LinearAlgebra.checksquare(A)
4843
τ = Vector{Householder{T}}(undef, n - 1)
@@ -60,6 +55,14 @@ LinearAlgebra.hessenberg!(A::StridedMatrix) = _hessenberg!(A)
6055

6156
Base.size(H::HessenbergFactorization, args...) = size(H.data, args...)
6257

58+
function Base.getproperty(F::HessenbergFactorization, s::Symbol)
59+
if s === :H
60+
return HessenbergMatrix{eltype(F.data),typeof(F.data)}(F.data)
61+
else
62+
return getfield(F, s)
63+
end
64+
end
65+
6366
# Schur
6467
struct Schur{T,S<:StridedMatrix} <: Factorization{T}
6568
data::S
@@ -74,19 +77,12 @@ function Base.getproperty(F::Schur, s::Symbol)
7477
end
7578
end
7679

77-
function wilkinson(Hmm, t, d)
78-
λ1 = (t + sqrt(t * t - 4d)) / 2
79-
λ2 = (t - sqrt(t * t - 4d)) / 2
80-
return ifelse(abs(Hmm - λ1) < abs(Hmm - λ2), λ1, λ2)
81-
end
82-
8380
# We currently absorb extra unsupported keywords in kwargs. These could e.g. be scale and permute. Do we want to check that these are false?
8481
function _schur!(
8582
H::HessenbergFactorization{T};
8683
tol = eps(real(T)),
8784
shiftmethod = :Francis,
8885
maxiter = 30 * size(H, 1),
89-
kwargs...,
9086
) where {T}
9187

9288
n = size(H, 1)
@@ -176,6 +172,8 @@ function _schur!(
176172
return Schur{T,typeof(HH)}(HH, τ)
177173
end
178174
_schur!(A::StridedMatrix; kwargs...) = _schur!(_hessenberg!(A); kwargs...)
175+
176+
# FIXME! Move this method to piracy extension
179177
LinearAlgebra.schur!(A::StridedMatrix; kwargs...) = _schur!(A; kwargs...)
180178

181179
function singleShiftQR!(

test/eigengeneral.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,17 @@ using Test, GenericLinearAlgebra, LinearAlgebra
1414
@test sort(vGLA, by = cplxord) sort(vLAPACK, by = cplxord)
1515
@test sort(vGLA, by = cplxord) sort(complex(eltype(A)).(vBig), by = cplxord)
1616
@test issorted(vBig, by = cplxord)
17+
18+
if T <: Complex
19+
@testset "Rayleigh shifts" begin
20+
@test sort(
21+
GenericLinearAlgebra._eigvals!(
22+
GenericLinearAlgebra._schur!(copy(A), shiftmethod = :Rayleigh),
23+
),
24+
by = t -> (real(t), imag(t)),
25+
) sort(eigvals(A), by = t -> (real(t), imag(t)))
26+
end
27+
end
1728
end
1829

1930
@testset "make sure that solver doesn't hang" begin
@@ -235,6 +246,9 @@ using Test, GenericLinearAlgebra, LinearAlgebra
235246
A[:, (i+1):end] = A[:, (i+1):end] * HM'
236247
end
237248
@test tril(A, -2) zeros(n, n) atol = 1e-14
238-
end
239249

250+
@test eigvals(HF.H) eigvals(A)
251+
@test eigvals(HF.H) eigvals!(copy(HF))
252+
@test HF.H \ ones(n) Matrix(HF.H) \ ones(n)
253+
end
240254
end

test/lapack.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,34 +69,41 @@ using GenericLinearAlgebra.LAPACK2
6969
@test !(_vals sort(eigvals(T)))
7070
end
7171

72-
@testset "syevd: eltype=$eltype, uplo=$uplo" for eltype in (Float32, Float64, ComplexF32, ComplexF64), uplo in ('U', 'L')
72+
@testset "syevd: eltype=$eltype, uplo=$uplo" for eltype in (
73+
Float32,
74+
Float64,
75+
ComplexF32,
76+
ComplexF64,
77+
),
78+
uplo in ('U', 'L')
79+
7380
A = randn(eltype, n, n)
7481
A = A + A'
7582
if eltype <: Real
7683
vals, vecs = LAPACK2.syevd!('V', uplo, copy(A))
7784
else
7885
vals, vecs = LAPACK2.heevd!('V', uplo, copy(A))
7986
end
80-
@test diag(vecs'*A*vecs) eigvals(A)
87+
@test diag(vecs' * A * vecs) eigvals(A)
8188
end
8289

83-
@testset "tgevc: eltype=$eltype, side=$side, howmny=$howmny" for eltype in (Float32, Float64), side in ('L', 'R', 'B'), howmny in ('A', #='B', =#'S')
90+
@testset "tgevc: eltype=$eltype, side=$side, howmny=$howmny" for eltype in
91+
(Float32, Float64),
92+
side in ('L', 'R', 'B'),
93+
howmny in ('A', 'S')
94+
#='B', =#
8495
select = ones(Int, n)
8596
S, P = triu(randn(eltype, n, n)), triu(randn(eltype, n, n))
86-
VL, VR, m = LAPACK2.tgevc!(
87-
side,
88-
howmny,
89-
select,
90-
copy(S),
91-
copy(P),
92-
)
97+
VL, VR, m = LAPACK2.tgevc!(side, howmny, select, copy(S), copy(P))
9398
if side ('R', 'B')
94-
w = diag(S*VR) ./ diag(P*VR)
95-
@test S*VR P*VR*Diagonal(w) rtol=sqrt(eps(eltype)) atol=sqrt(eps(eltype))
99+
w = diag(S * VR) ./ diag(P * VR)
100+
@test S * VR P * VR * Diagonal(w) rtol = sqrt(eps(eltype)) atol =
101+
sqrt(eps(eltype))
96102
end
97103
if side ('L', 'B')
98-
w = w = diag(VL'*S) ./ diag(VL'*P)
99-
@test VL'*S Diagonal(w)*VL'*P rtol=sqrt(eps(eltype)) atol=sqrt(eps(eltype))
104+
w = w = diag(VL' * S) ./ diag(VL' * P)
105+
@test VL' * S Diagonal(w) * VL' * P rtol = sqrt(eps(eltype)) atol =
106+
sqrt(eps(eltype))
100107
end
101108
end
102109
end

0 commit comments

Comments
 (0)