|
| 1 | +import Base: next, start, done |
| 2 | + |
1 | 3 | export chebyshev, chebyshev!
|
2 | 4 |
|
| 5 | +type ChebyshevIterable{precT, matT, vecT, realT <: Real} |
| 6 | + Pl::precT |
| 7 | + A::matT |
| 8 | + b::vecT |
| 9 | + |
| 10 | + x::vecT |
| 11 | + r::vecT |
| 12 | + u::vecT |
| 13 | + c::vecT |
| 14 | + |
| 15 | + α::realT |
| 16 | + |
| 17 | + λ_avg::realT |
| 18 | + λ_diff::realT |
| 19 | + |
| 20 | + resnorm::realT |
| 21 | + reltol::realT |
| 22 | + maxiter::Int |
| 23 | + mv_products::Int |
| 24 | +end |
| 25 | + |
| 26 | +converged(c::ChebyshevIterable) = c.resnorm ≤ c.reltol |
| 27 | +start(::ChebyshevIterable) = 0 |
| 28 | +done(c::ChebyshevIterable, iteration::Int) = iteration ≥ c.maxiter || converged(c) |
| 29 | + |
| 30 | +function next(cheb::ChebyshevIterable, iteration::Int) |
| 31 | + T = eltype(cheb.u) |
| 32 | + |
| 33 | + solve!(cheb.c, cheb.Pl, cheb.r) |
| 34 | + |
| 35 | + if iteration == 1 |
| 36 | + cheb.α = T(2) / cheb.λ_avg |
| 37 | + copy!(cheb.u, cheb.c) |
| 38 | + else |
| 39 | + β = (cheb.λ_diff * cheb.α / 2) ^ 2 |
| 40 | + cheb.α = inv(cheb.λ_avg - β) |
| 41 | + |
| 42 | + # Almost an axpy u = c + βu |
| 43 | + scale!(cheb.u, β) |
| 44 | + @blas! cheb.u += one(T) * cheb.c |
| 45 | + end |
| 46 | + |
| 47 | + A_mul_B!(cheb.c, cheb.A, cheb.u) |
| 48 | + cheb.mv_products += 1 |
| 49 | + |
| 50 | + @blas! cheb.x += cheb.α * cheb.u |
| 51 | + @blas! cheb.r -= cheb.α * cheb.c |
| 52 | + |
| 53 | + cheb.resnorm = norm(cheb.r) |
| 54 | + |
| 55 | + cheb.resnorm, iteration + 1 |
| 56 | +end |
| 57 | + |
| 58 | +chebyshev_iterable(A, b, λmin::Real, λmax::Real; kwargs...) = |
| 59 | + chebyshev_iterable!(zerox(A, b), A, b, λmin, λmax; kwargs...) |
| 60 | + |
| 61 | +function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real; |
| 62 | + tol = sqrt(eps(real(eltype(b)))), maxiter = size(A, 1), Pl = Identity(), initially_zero = false) |
| 63 | + |
| 64 | + λ_avg = (λmax + λmin) / 2 |
| 65 | + λ_diff = (λmax - λmin) / 2 |
| 66 | + |
| 67 | + T = eltype(b) |
| 68 | + r = copy(b) |
| 69 | + u = zeros(x) |
| 70 | + c = similar(x) |
| 71 | + |
| 72 | + # One MV product less |
| 73 | + if initially_zero |
| 74 | + resnorm = norm(r) |
| 75 | + reltol = tol * resnorm |
| 76 | + mv_products = 0 |
| 77 | + else |
| 78 | + A_mul_B!(c, A, x) |
| 79 | + @blas! r -= one(T) * c |
| 80 | + resnorm = norm(r) |
| 81 | + reltol = tol * norm(b) |
| 82 | + mv_products = 1 |
| 83 | + end |
| 84 | + |
| 85 | + ChebyshevIterable(Pl, A, b, |
| 86 | + x, r, u, c, |
| 87 | + zero(real(T)), |
| 88 | + λ_avg, λ_diff, |
| 89 | + resnorm, reltol, maxiter, mv_products |
| 90 | + ) |
| 91 | +end |
| 92 | + |
3 | 93 | ####################
|
4 | 94 | # API method calls #
|
5 | 95 | ####################
|
6 | 96 |
|
7 | 97 | chebyshev(A, b, λmin::Real, λmax::Real; kwargs...) =
|
8 |
| - chebyshev!(zerox(A, b), A, b, λmin, λmax; kwargs...) |
| 98 | + chebyshev!(zerox(A, b), A, b, λmin, λmax; initially_zero = true, kwargs...) |
9 | 99 |
|
10 | 100 | function chebyshev!(x, A, b, λmin::Real, λmax::Real;
|
11 |
| - n::Int=size(A,2), tol::Real = sqrt(eps(typeof(real(b[1])))), |
12 |
| - maxiter::Int = n^3, log::Bool=false, kwargs... |
13 |
| - ) |
14 |
| - K = KrylovSubspace(A, n, 1, Adivtype(A, b)) |
15 |
| - init!(K, x) |
| 101 | + Pl = Identity(), |
| 102 | + tol::Real=sqrt(eps(real(eltype(b)))), |
| 103 | + maxiter::Int=size(A, 1), |
| 104 | + log::Bool=false, |
| 105 | + verbose::Bool=false, |
| 106 | + initially_zero::Bool=false |
| 107 | +) |
16 | 108 | history = ConvergenceHistory(partial=!log)
|
17 | 109 | history[:tol] = tol
|
18 |
| - reserve!(history,:resnorm,maxiter) |
19 |
| - chebyshev_method!(history, x, K, b, λmin, λmax; tol=tol, maxiter=maxiter, kwargs...) |
20 |
| - log && shrink!(history) |
21 |
| - log ? (x, history) : x |
22 |
| -end |
| 110 | + reserve!(history, :resnorm, maxiter) |
23 | 111 |
|
24 |
| -######################### |
25 |
| -# Method Implementation # |
26 |
| -######################### |
| 112 | + verbose && @printf("=== chebyshev ===\n%4s\t%7s\n","iter","resnorm") |
27 | 113 |
|
28 |
| -function chebyshev_method!( |
29 |
| - log::ConvergenceHistory, x, K::KrylovSubspace, b, λmin::Real, λmax::Real; |
30 |
| - Pr = 1, tol::Real = sqrt(eps(typeof(real(b[1])))), maxiter::Int = K.n^3, |
31 |
| - verbose::Bool=false |
32 |
| - ) |
| 114 | + iterable = chebyshev_iterable!(x, A, b, λmin, λmax; tol=tol, maxiter=maxiter, Pl=Pl, initially_zero=initially_zero) |
| 115 | + history.mvps = iterable.mv_products |
| 116 | + for (iteration, resnorm) = enumerate(iterable) |
| 117 | + nextiter!(history) |
| 118 | + history.mvps = iterable.mv_products |
| 119 | + push!(history, :resnorm, resnorm) |
| 120 | + verbose && @printf("%3d\t%1.2e\n", iteration, resnorm) |
| 121 | + end |
| 122 | + verbose && println() |
| 123 | + setconv(history, converged(iterable)) |
| 124 | + log && shrink!(history) |
33 | 125 |
|
34 |
| - verbose && @printf("=== chebyshev ===\n%4s\t%7s\n","iter","resnorm") |
35 |
| - local α, p |
36 |
| - K.order = 1 |
37 |
| - tol = tol*norm(b) |
38 |
| - log.mvps=1 |
39 |
| - r = b - nextvec(K) |
40 |
| - d::eltype(b) = (λmax + λmin)/2 |
41 |
| - c::eltype(b) = (λmax - λmin)/2 |
42 |
| - for iter = 1:maxiter |
43 |
| - nextiter!(log, mvps=1) |
44 |
| - z = solve(Pr,r) |
45 |
| - if iter == 1 |
46 |
| - p = z |
47 |
| - α = 2/d |
48 |
| - else |
49 |
| - β = (c*α/2)^2 |
50 |
| - α = 1/(d - β) |
51 |
| - p = z + β*p |
52 |
| - end |
53 |
| - append!(K, p) |
54 |
| - @blas! x += α*p |
55 |
| - @blas! r -= α*nextvec(K) |
56 |
| - #Check convergence |
57 |
| - resnorm = norm(r) |
58 |
| - push!(log, :resnorm, resnorm) |
59 |
| - verbose && @printf("%3d\t%1.2e\n",iter,resnorm) |
60 |
| - resnorm < tol && break |
61 |
| - end |
62 |
| - verbose && @printf("\n") |
63 |
| - setconv(log, 0<=norm(r)<tol) |
64 |
| - x |
| 126 | + log ? (x, history) : x |
65 | 127 | end
|
66 | 128 |
|
67 | 129 | #################
|
|
108 | 170 |
|
109 | 171 | ## Keywords
|
110 | 172 |
|
111 |
| -`Pr = 1`: right preconditioner of the method. |
| 173 | +`Pl = 1`: left preconditioner of the method. |
112 | 174 |
|
113 | 175 | `tol::Real = sqrt(eps())`: stopping tolerance.
|
114 | 176 |
|
|
0 commit comments