Skip to content

Commit 5838094

Browse files
committed
Make it easier to dispatch to general eigensolver methods defined
in GenericLinearAlgebra by prefixing the computational methods with underscores.
1 parent 8e84679 commit 5838094

File tree

2 files changed

+37
-28
lines changed

2 files changed

+37
-28
lines changed

src/eigenGeneral.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ using Printf
22
using LinearAlgebra
33
using LinearAlgebra: Givens, Rotation
44

5-
import Base: copy, getindex, size
6-
import LinearAlgebra: eigvals!, ldiv!
7-
85
# Auxiliary
96
function adiagmax(A::StridedMatrix)
107
adm = zero(typeof(real(A[1])))
@@ -21,14 +18,14 @@ struct HessenbergMatrix{T,S<:StridedMatrix} <: AbstractMatrix{T}
2118
data::S
2219
end
2320

24-
copy(H::HessenbergMatrix{T,S}) where {T,S} = HessenbergMatrix{T,S}(copy(H.data))
21+
Base.copy(H::HessenbergMatrix{T,S}) where {T,S} = HessenbergMatrix{T,S}(copy(H.data))
2522

26-
getindex(H::HessenbergMatrix{T,S}, i::Integer, j::Integer) where {T,S} = i > j + 1 ? zero(T) : H.data[i,j]
23+
Base.getindex(H::HessenbergMatrix{T,S}, i::Integer, j::Integer) where {T,S} = i > j + 1 ? zero(T) : H.data[i,j]
2724

28-
size(H::HessenbergMatrix) = size(H.data)
29-
size(H::HessenbergMatrix, i::Integer) = size(H.data, i)
25+
Base.size(H::HessenbergMatrix) = size(H.data)
26+
Base.size(H::HessenbergMatrix, i::Integer) = size(H.data, i)
3027

31-
function ldiv!(H::HessenbergMatrix, B::AbstractVecOrMat)
28+
function LinearAlgebra.ldiv!(H::HessenbergMatrix, B::AbstractVecOrMat)
3229
n = size(H, 1)
3330
Hd = H.data
3431
for i = 1:n-1
@@ -45,7 +42,7 @@ struct HessenbergFactorization{T, S<:StridedMatrix,U} <: Factorization{T}
4542
τ::Vector{U}
4643
end
4744

48-
function hessfact!(A::StridedMatrix{T}) where T
45+
function _hessenberg!(A::StridedMatrix{T}) where T
4946
n = LinearAlgebra.checksquare(A)
5047
τ = Vector{Householder{T}}(undef, n - 1)
5148
for i = 1:n - 1
@@ -58,8 +55,9 @@ function hessfact!(A::StridedMatrix{T}) where T
5855
end
5956
return HessenbergFactorization{T, typeof(A), eltype(τ)}(A, τ)
6057
end
58+
LinearAlgebra.hessenberg!(A::StridedMatrix) = _hessenberg!(A)
6159

62-
size(H::HessenbergFactorization, args...) = size(H.data, args...)
60+
Base.size(H::HessenbergFactorization, args...) = size(H.data, args...)
6361

6462
# Schur
6563
struct Schur{T,S<:StridedMatrix} <: Factorization{T}
@@ -73,8 +71,8 @@ function wilkinson(Hmm, t, d)
7371
return ifelse(abs(Hmm - λ1) < abs(Hmm - λ2), λ1, λ2)
7472
end
7573

76-
77-
function schurfact!(H::HessenbergFactorization{T}; tol = eps(T), debug = false, shiftmethod = :Wilkinson, maxiter = 100*size(H, 1)) where T<:Real
74+
# We currently absorb extra unsupported keywords in kwargs. These could e.g. be scale and permute. Do we want to check that these are false?
75+
function _schur!(H::HessenbergFactorization{T}; tol = eps(T), debug = false, shiftmethod = :Wilkinson, maxiter = 100*size(H, 1), kwargs...) where T<:Real
7876
n = size(H, 1)
7977
istart = 1
8078
iend = n
@@ -143,7 +141,8 @@ function schurfact!(H::HessenbergFactorization{T}; tol = eps(T), debug = false,
143141

144142
return Schur{T,typeof(HH)}(HH, τ)
145143
end
146-
schurfact!(A::StridedMatrix; kwargs...) = schurfact!(hessfact!(A); kwargs...)
144+
_schur!(A::StridedMatrix; kwargs...) = _schur!(_hessenberg!(A); kwargs...)
145+
LinearAlgebra.schur!(A::StridedMatrix; kwargs...) = _schur!(A; kwargs...)
147146

148147
function singleShiftQR!(HH::StridedMatrix, τ::Rotation, shift::Number, istart::Integer, iend::Integer)
149148
m = size(HH, 1)
@@ -219,11 +218,16 @@ function doubleShiftQR!(HH::StridedMatrix, τ::Rotation, shiftTrace::Number, shi
219218
return HH
220219
end
221220

222-
eigvals!(A::StridedMatrix; kwargs...) = eigvals!(schurfact!(A; kwargs...))
223-
eigvals!(H::HessenbergMatrix; kwargs...) = eigvals!(schurfact!(H, kwargs...))
224-
eigvals!(H::HessenbergFactorization; kwargs...) = eigvals!(schurfact!(H, kwargs...))
221+
_eigvals!(A::StridedMatrix; kwargs...) = _eigvals!(_schur!(A; kwargs...))
222+
_eigvals!(H::HessenbergMatrix; kwargs...) = _eigvals!(_schur!(H, kwargs...))
223+
_eigvals!(H::HessenbergFactorization; kwargs...) = _eigvals!(_schur!(H, kwargs...))
225224

226-
function eigvals!(S::Schur{T}; tol = eps(T)) where T
225+
# Overload methods from LinearAlgebra to make them work generically
226+
LinearAlgebra.eigvals!(A::StridedMatrix; kwargs...) = _eigvals!(A; kwargs...)
227+
LinearAlgebra.eigvals!(H::HessenbergMatrix; kwargs...) = _eigvals!(H, kwargs...)
228+
LinearAlgebra.eigvals!(H::HessenbergFactorization; kwargs...) = _eigvals!(H, kwargs...)
229+
230+
function _eigvals!(S::Schur{T}; tol = eps(T)) where T
227231
HH = S.data
228232
n = size(HH, 1)
229233
vals = Vector{Complex{T}}(undef, n)
@@ -244,6 +248,8 @@ function eigvals!(S::Schur{T}; tol = eps(T)) where T
244248
i += 2
245249
end
246250
end
247-
if i == n vals[i] = HH[n, n] end
251+
if i == n
252+
vals[i] = HH[n, n]
253+
end
248254
return vals
249255
end

test/eigengeneral.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
using GenericLinearAlgebra
22

3-
@testset "General eigen problems" begin
4-
n = 10
3+
@testset "The General eigenvalue problem" begin
4+
5+
@testset "General eigen problem with n=$n" for n in (10, 100, 200)
56
A = randn(n,n)
6-
v1 = LinearAlgebra.eigvals!(copy(A))
7-
v2 = eigvals(A)
8-
vBig = LinearAlgebra.eigvals!(big.(A))
9-
@test sort(real(v1)) sort(real(v2))
10-
@test sort(imag(v1)) sort(imag(v2))
11-
@test sort(real(v1)) sort(real(map(Complex{Float64}, vBig)))
12-
@test sort(imag(v1)) sort(imag(map(Complex{Float64}, vBig)))
7+
vGLA = GenericLinearAlgebra._eigvals!(copy(A))
8+
vLAPACK = eigvals(A)
9+
vBig = eigvals(big.(A)) # not defined in LinearAlgebra so will dispatch to the version in GenericLinearAlgebra
10+
@test sort(real(vGLA)) sort(real(vLAPACK))
11+
@test sort(imag(vGLA)) sort(imag(vLAPACK))
12+
@test sort(real(vGLA)) sort(real(map(Complex{Float64}, vBig)))
13+
@test sort(imag(vGLA)) sort(imag(map(Complex{Float64}, vBig)))
1314
end
1415

1516
@testset "make sure that solver doesn't hang" begin
1617
for i in 1:1000
1718
A = randn(8, 8)
18-
sort(abs.(LinearAlgebra.eigvals!(copy(A)))) sort(abs.(eigvals(A)))
19+
sort(abs.(GenericLinearAlgebra._eigvals!(copy(A)))) sort(abs.(eigvals(A)))
1920
end
21+
end
22+
2023
end

0 commit comments

Comments
 (0)