Skip to content

Commit 2915f28

Browse files
committed
Add flag for type of shifts in general eigensolver and use Wilkinson
double shifts by default. Add perturbation in case the trace of a 2x2 block is zero. Fixes #27
1 parent cacaffc commit 2915f28

File tree

2 files changed

+51
-33
lines changed

2 files changed

+51
-33
lines changed

src/eigenGeneral.jl

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -70,75 +70,81 @@ module EigenGeneral
7070
R::Rotation
7171
end
7272

73-
function schurfact!{T<:Real}(H::HessenbergFactorization{T}; tol = eps(T), debug = false)
73+
function wilkinson(Hmm, t, d)
74+
λ1 = (t + sqrt(t*t - 4d))/2
75+
λ2 = (t - sqrt(t*t - 4d))/2
76+
return ifelse(abs(Hmm - λ1) < abs(Hmm - λ2), λ1, λ2)
77+
end
78+
79+
80+
function schurfact!{T<:Real}(H::HessenbergFactorization{T}; tol = eps(T), debug = false, shiftmethod = :Wilkinson, maxiter = 100*size(H, 1))
7481
n = size(H, 1)
7582
istart = 1
7683
iend = n
7784
HH = H.data
7885
τ = Rotation(Givens{T}[])
7986

87+
# iteration count
88+
i = 0
89+
8090
@inbounds while true
91+
i += 1
92+
if i > maxiter
93+
throw(ArgumentError("iteration limit $maxiter reached"))
94+
end
95+
8196
# Determine if the matrix splits. Find lowest positioned subdiagonal "zero"
8297
for istart = iend - 1:-1:1
83-
# debug && @printf("istart: %6d, iend %6d\n", istart, iend)
84-
# istart == minstart && break
8598
if abs(HH[istart + 1, istart]) < tol*(abs(HH[istart, istart]) + abs(HH[istart + 1, istart + 1]))
8699
istart += 1
87-
debug && @printf("Top deflation! Subdiagonal element is: %10.3e and istart now %6d\n", HH[istart, istart - 1], istart)
100+
debug && @printf("Split! Subdiagonal element is: %10.3e and istart now %6d\n", HH[istart, istart - 1], istart)
88101
break
89102
elseif istart > 1 && abs(HH[istart, istart - 1]) < tol*(abs(HH[istart - 1, istart - 1]) + abs(HH[istart, istart]))
90-
debug && @printf("Top deflation! Next subdiagonal element is: %10.3e and istart now %6d\n", HH[istart, istart - 1], istart)
103+
debug && @printf("Split! Next subdiagonal element is: %10.3e and istart now %6d\n", HH[istart, istart - 1], istart)
91104
break
92105
end
93106
end
94107

95108
# if block size is one we deflate
96109
if istart >= iend
110+
debug && @printf("Bottom deflation! Block size is one. New iend is %6d\n", iend - 1)
97111
iend -= 1
98112

99113
# and the same for a 2x2 block
100114
elseif istart + 1 == iend
115+
debug && @printf("Bottom deflation! Block size is two. New iend is %6d\n", iend - 2)
101116
iend -= 2
102117

103-
# if we don't deflate we'll run either a single or double shift bulge chase
118+
# run a QR iteration
119+
# shift method is specified with shiftmethod kw argument
104120
else
105121
Hmm = HH[iend, iend]
106122
Hm1m1 = HH[iend - 1, iend - 1]
107123
d = Hm1m1*Hmm - HH[iend, iend - 1]*HH[iend - 1, iend]
108124
t = Hm1m1 + Hmm
125+
t = iszero(t) ? eps(one(t)) : t # introduce a small perturbation for zero shifts
109126
debug && @printf("block start is: %6d, block end is: %6d, d: %10.3e, t: %10.3e\n", istart, iend, d, t)
110127

111-
# For small (sub) problems use Raleigh quotion shift and single shift
112-
if iend <= istart + 2
113-
σ = HH[iend, iend]
128+
if shiftmethod == :Wilkinson
129+
debug && @printf("Double shift with Wilkinson shift! Subdiagonal is: %10.3e, last subdiagonal is: %10.3e\n", HH[iend, iend - 1], HH[iend - 1, iend - 2])
114130

115131
# Run a bulge chase
116-
singleShiftQR!(HH, τ, σ, istart, iend)
117-
118-
# If the eigenvales of the 2x2 block are real use single shift
119-
elseif t*t > 4d
120-
debug && @printf("Single shift! subdiagonal is: %10.3e\n", HH[iend, iend - 1])
121-
122-
# Calculate the Wilkinson shift
123-
λ1 = (t + sqrt(t*t - 4d))/2
124-
λ2 = (t - sqrt(t*t - 4d))/2
125-
σ = abs(Hmm - λ1) < abs(Hmm - λ2) ? λ1 : λ2
132+
doubleShiftQR!(HH, τ, t, d, istart, iend)
133+
elseif shiftmethod == :Rayleigh
134+
debug && @printf("Single shift with Rayleigh shift! Subdiagonal is: %10.3e\n", HH[iend, iend - 1])
126135

127136
# Run a bulge chase
128-
singleShiftQR!(HH, τ, σ, istart, iend)
129-
130-
# else use double shift
137+
singleShiftQR!(HH, τ, Hmm, istart, iend)
131138
else
132-
debug && @printf("Double shift! subdiagonal is: %10.3e, last subdiagonal is: %10.3e\n", HH[iend, iend - 1], HH[iend - 1, iend - 2])
133-
doubleShiftQR!(HH, τ, t, d, istart, iend)
139+
throw(ArgumentError("only support supported shift methods are :Wilkinson (default) and :Rayleigh. You supplied $shiftmethod"))
134140
end
135141
end
136142
if iend <= 2 break end
137143
end
138144

139145
return Schur{T,typeof(HH)}(HH, τ)
140146
end
141-
schurfact!(A::StridedMatrix; tol = eps(float(real(eltype(A)))), debug = false) = schurfact!(hessfact!(A), tol = tol, debug = debug)
147+
# Calls `hessfact!(A)` and hands the result to the `schurfact!` method for that
# type, forwarding all keyword arguments unchanged.
function schurfact!(A::StridedMatrix; kwargs...)
    hess = hessfact!(A)
    return schurfact!(hess; kwargs...)
end
142148

143149
function singleShiftQR!(HH::StridedMatrix, τ::Rotation, shift::Number, istart::Integer, iend::Integer)
144150
m = size(HH, 1)
@@ -172,16 +178,21 @@ module EigenGeneral
172178
H21 = HH[istart + 1, istart]
173179
Htmp11 = HH[istart + 2, istart]
174180
HH[istart + 2, istart] = 0
175-
Htmp21 = HH[istart + 3, istart]
176-
HH[istart + 3, istart] = 0
177-
Htmp22 = HH[istart + 3, istart + 1]
178-
HH[istart + 3, istart + 1] = 0
181+
if istart + 3 <= m
182+
Htmp21 = HH[istart + 3, istart]
183+
HH[istart + 3, istart] = 0
184+
Htmp22 = HH[istart + 3, istart + 1]
185+
HH[istart + 3, istart + 1] = 0
186+
else
187+
# values don't matter in this case but variables should be initialized
188+
Htmp21 = Htmp22 = Htmp11
189+
end
179190
G1, r = givens(H11*H11 + HH[istart, istart + 1]*H21 - shiftTrace*H11 + shiftDeterminant, H21*(H11 + HH[istart + 1, istart + 1] - shiftTrace), istart, istart + 1)
180191
G2, _ = givens(r, H21*HH[istart + 2, istart + 1], istart, istart + 2)
181192
vHH = view(HH, :, istart:m)
182193
A_mul_B!(G1, vHH)
183194
A_mul_B!(G2, vHH)
184-
vHH = view(HH, 1:istart + 3, :)
195+
vHH = view(HH, 1:min(istart + 3, m), :)
185196
A_mul_Bc!(vHH, G1)
186197
A_mul_Bc!(vHH, G2)
187198
A_mul_B!(G1, τ)
@@ -209,9 +220,9 @@ module EigenGeneral
209220
return HH
210221
end
211222

212-
eigvals!(A::StridedMatrix; tol = eps(one(A[1])), debug = false) = eigvals!(schurfact!(A, tol = tol, debug = debug))
213-
eigvals!(H::HessenbergMatrix; tol = eps(one(A[1])), debug = false) = eigvals!(schurfact!(H, tol = tol, debug = debug))
214-
eigvals!(H::HessenbergFactorization; tol = eps(one(A[1])), debug = false) = eigvals!(schurfact!(H, tol = tol, debug = debug))
223+
# Runs `schurfact!` on `A` with all keyword arguments forwarded unchanged, then
# passes the result to the `eigvals!` method for that type.
function eigvals!(A::StridedMatrix; kwargs...)
    schur = schurfact!(A; kwargs...)
    return eigvals!(schur)
end
224+
# Runs `schurfact!` on `H` with all keyword arguments forwarded, then passes the
# result to `eigvals!`. The `;` before `kwargs...` is required: after a `,` the
# pairs would be splatted as positional arguments instead of being passed as keywords.
eigvals!(H::HessenbergMatrix; kwargs...) = eigvals!(schurfact!(H; kwargs...))
225+
# Runs `schurfact!` on `H` with all keyword arguments forwarded, then passes the
# result to `eigvals!`. The `;` before `kwargs...` is required: after a `,` the
# pairs would be splatted as positional arguments instead of being passed as keywords.
eigvals!(H::HessenbergFactorization; kwargs...) = eigvals!(schurfact!(H; kwargs...))
215226

216227
function eigvals!{T}(S::Schur{T}; tol = eps(T))
217228
HH = S.data

test/eigengeneral.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,11 @@ using LinearAlgebra
1010
@test sort(imag(v1)) ≈ sort(imag(v2))
1111
@test sort(real(v1)) ≈ sort(real(map(Complex{Float64}, vBig)))
1212
@test sort(imag(v1)) ≈ sort(imag(map(Complex{Float64}, vBig)))
13+
end
14+
15+
@testset "make sure that solver doesn't hang" begin
16+
for i in 1:1000
17+
A = randn(8, 8)
18+
sort(abs.(LinearAlgebra.EigenGeneral.eigvals!(copy(A)))) ≈ sort(abs.(eigvals(A)))
19+
end
1320
end

0 commit comments

Comments
 (0)