Skip to content

Commit 1c6085f

Browse files
haampieandreasnoack
authored andcommitted
Chebyshev as iterator (#145)
* Improve tests: with preconditioner and a more sane eigenspectrum for the SPD matrix (no eigenvalues close to 0 to improve convergence) * Use the iterator idea
1 parent d48a419 commit 1c6085f

File tree

2 files changed

+144
-65
lines changed

2 files changed

+144
-65
lines changed

src/chebyshev.jl

Lines changed: 113 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,129 @@
1+
import Base: next, start, done
2+
13
export chebyshev, chebyshev!
24

5+
type ChebyshevIterable{precT, matT, vecT, realT <: Real}
6+
Pl::precT
7+
A::matT
8+
b::vecT
9+
10+
x::vecT
11+
r::vecT
12+
u::vecT
13+
c::vecT
14+
15+
α::realT
16+
17+
λ_avg::realT
18+
λ_diff::realT
19+
20+
resnorm::realT
21+
reltol::realT
22+
maxiter::Int
23+
mv_products::Int
24+
end
25+
26+
converged(c::ChebyshevIterable) = c.resnorm c.reltol
27+
start(::ChebyshevIterable) = 0
28+
done(c::ChebyshevIterable, iteration::Int) = iteration c.maxiter || converged(c)
29+
30+
function next(cheb::ChebyshevIterable, iteration::Int)
31+
T = eltype(cheb.u)
32+
33+
solve!(cheb.c, cheb.Pl, cheb.r)
34+
35+
if iteration == 1
36+
cheb.α = T(2) / cheb.λ_avg
37+
copy!(cheb.u, cheb.c)
38+
else
39+
β = (cheb.λ_diff * cheb.α / 2) ^ 2
40+
cheb.α = inv(cheb.λ_avg - β)
41+
42+
# Almost an axpy u = c + βu
43+
scale!(cheb.u, β)
44+
@blas! cheb.u += one(T) * cheb.c
45+
end
46+
47+
A_mul_B!(cheb.c, cheb.A, cheb.u)
48+
cheb.mv_products += 1
49+
50+
@blas! cheb.x += cheb.α * cheb.u
51+
@blas! cheb.r -= cheb.α * cheb.c
52+
53+
cheb.resnorm = norm(cheb.r)
54+
55+
cheb.resnorm, iteration + 1
56+
end
57+
58+
chebyshev_iterable(A, b, λmin::Real, λmax::Real; kwargs...) =
59+
chebyshev_iterable!(zerox(A, b), A, b, λmin, λmax; kwargs...)
60+
61+
function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real;
62+
tol = sqrt(eps(real(eltype(b)))), maxiter = size(A, 1), Pl = Identity(), initially_zero = false)
63+
64+
λ_avg = (λmax + λmin) / 2
65+
λ_diff = (λmax - λmin) / 2
66+
67+
T = eltype(b)
68+
r = copy(b)
69+
u = zeros(x)
70+
c = similar(x)
71+
72+
# One MV product less
73+
if initially_zero
74+
resnorm = norm(r)
75+
reltol = tol * resnorm
76+
mv_products = 0
77+
else
78+
A_mul_B!(c, A, x)
79+
@blas! r -= one(T) * c
80+
resnorm = norm(r)
81+
reltol = tol * norm(b)
82+
mv_products = 1
83+
end
84+
85+
ChebyshevIterable(Pl, A, b,
86+
x, r, u, c,
87+
zero(real(T)),
88+
λ_avg, λ_diff,
89+
resnorm, reltol, maxiter, mv_products
90+
)
91+
end
92+
393
####################
494
# API method calls #
595
####################
696

797
chebyshev(A, b, λmin::Real, λmax::Real; kwargs...) =
8-
chebyshev!(zerox(A, b), A, b, λmin, λmax; kwargs...)
98+
chebyshev!(zerox(A, b), A, b, λmin, λmax; initially_zero = true, kwargs...)
999

10100
function chebyshev!(x, A, b, λmin::Real, λmax::Real;
11-
n::Int=size(A,2), tol::Real = sqrt(eps(typeof(real(b[1])))),
12-
maxiter::Int = n^3, log::Bool=false, kwargs...
13-
)
14-
K = KrylovSubspace(A, n, 1, Adivtype(A, b))
15-
init!(K, x)
101+
Pl = Identity(),
102+
tol::Real=sqrt(eps(real(eltype(b)))),
103+
maxiter::Int=size(A, 1),
104+
log::Bool=false,
105+
verbose::Bool=false,
106+
initially_zero::Bool=false
107+
)
16108
history = ConvergenceHistory(partial=!log)
17109
history[:tol] = tol
18-
reserve!(history,:resnorm,maxiter)
19-
chebyshev_method!(history, x, K, b, λmin, λmax; tol=tol, maxiter=maxiter, kwargs...)
20-
log && shrink!(history)
21-
log ? (x, history) : x
22-
end
110+
reserve!(history, :resnorm, maxiter)
23111

24-
#########################
25-
# Method Implementation #
26-
#########################
112+
verbose && @printf("=== chebyshev ===\n%4s\t%7s\n","iter","resnorm")
27113

28-
function chebyshev_method!(
29-
log::ConvergenceHistory, x, K::KrylovSubspace, b, λmin::Real, λmax::Real;
30-
Pr = 1, tol::Real = sqrt(eps(typeof(real(b[1])))), maxiter::Int = K.n^3,
31-
verbose::Bool=false
32-
)
114+
iterable = chebyshev_iterable!(x, A, b, λmin, λmax; tol=tol, maxiter=maxiter, Pl=Pl, initially_zero=initially_zero)
115+
history.mvps = iterable.mv_products
116+
for (iteration, resnorm) = enumerate(iterable)
117+
nextiter!(history)
118+
history.mvps = iterable.mv_products
119+
push!(history, :resnorm, resnorm)
120+
verbose && @printf("%3d\t%1.2e\n", iteration, resnorm)
121+
end
122+
verbose && println()
123+
setconv(history, converged(iterable))
124+
log && shrink!(history)
33125

34-
verbose && @printf("=== chebyshev ===\n%4s\t%7s\n","iter","resnorm")
35-
local α, p
36-
K.order = 1
37-
tol = tol*norm(b)
38-
log.mvps=1
39-
r = b - nextvec(K)
40-
d::eltype(b) = (λmax + λmin)/2
41-
c::eltype(b) = (λmax - λmin)/2
42-
for iter = 1:maxiter
43-
nextiter!(log, mvps=1)
44-
z = solve(Pr,r)
45-
if iter == 1
46-
p = z
47-
α = 2/d
48-
else
49-
β = (c*α/2)^2
50-
α = 1/(d - β)
51-
p = z + β*p
52-
end
53-
append!(K, p)
54-
@blas! x += α*p
55-
@blas! r -= α*nextvec(K)
56-
#Check convergence
57-
resnorm = norm(r)
58-
push!(log, :resnorm, resnorm)
59-
verbose && @printf("%3d\t%1.2e\n",iter,resnorm)
60-
resnorm < tol && break
61-
end
62-
verbose && @printf("\n")
63-
setconv(log, 0<=norm(r)<tol)
64-
x
126+
log ? (x, history) : x
65127
end
66128

67129
#################
@@ -108,7 +170,7 @@ $arg
108170
109171
## Keywords
110172
111-
`Pr = 1`: right preconditioner of the method.
173+
`Pl = 1`: left preconditioner of the method.
112174
113175
`tol::Real = sqrt(eps())`: stopping tolerance.
114176

test/chebyshev.jl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,41 @@ using LinearMaps
55

66
srand(1234321)
77

8+
function randSPD(T, n)
9+
A = rand(T, n, n) + n * eye(T, n)
10+
A = A' + A
11+
A = A' * A
12+
end
13+
814
#Chebyshev
915
facts("Chebyshev") do
16+
n = 10
1017
for T in (Float32, Float64, Complex64, Complex128)
1118
context("Matrix{$T}") do
12-
A=convert(Matrix{T}, randn(n,n))
13-
T<:Complex && (A+=convert(Matrix{T}, im*randn(n,n)))
14-
A=A+A'
15-
A=A'*A #Construct SPD matrix
16-
b=convert(Vector{T}, randn(n))
17-
T<:Complex && (b+=convert(Vector{T}, im*randn(n)))
18-
b=b/norm(b)
19-
tol = 0.1 #For some reason Chebyshev is very slow
20-
v = eigvals(A)
21-
mxv = maximum(v)
22-
mnv = minimum(v)
23-
x_cheby, c_cheby= chebyshev(A, b, mxv+(mxv-mnv)/100, mnv-(mxv-mnv)/100, tol=tol, maxiter=10^5, log=true)
24-
@fact c_cheby.isconverged --> true
25-
@fact norm(A*x_cheby-b) --> less_than(tol)
19+
A = randSPD(T, n)
20+
b = rand(T, n)
21+
b /= norm(b)
22+
23+
tol = sqrt(eps(real(T)))
24+
mnv, mxv = eigmin(A), eigmax(A)
25+
Δ = (mxv - mnv) / 100
26+
27+
x, history = chebyshev(A, b, mnv - Δ, mxv + Δ, tol=tol, maxiter=10n, log=true)
28+
@fact history.isconverged --> true
29+
@fact norm(A * x - b) --> less_than(tol)
30+
31+
context("Preconditioned") do
32+
B = randSPD(T, n)
33+
B_fact = cholfact!(B)
34+
BA = B_fact \ A
35+
λs = eigvals(BA)
36+
mnv, mxv = minimum(real(λs)), maximum(real(λs))
37+
Δ = (mxv - mnv) / 100
38+
39+
x, history = chebyshev(A, b, mnv - Δ, mxv + Δ, Pl = B_fact, tol=tol, maxiter=10n, log=true)
40+
@fact history.isconverged --> true
41+
@fact norm(A * x - b) --> less_than(tol)
42+
end
2643
end
2744
end
2845
end

0 commit comments

Comments
 (0)