|
1 |
| -export cg, cg! |
2 |
| - |
3 |
| -cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; kwargs...) |
| 1 | +import Base: start, next, done |
| 2 | + |
| 3 | +export cg, cg!, CGIterable, PCGIterable, cg_iterator, cg_iterator! |
| 4 | + |
| 5 | +type CGIterable{matT, vecT <: AbstractVector, numT <: Real} |
| 6 | + A::matT |
| 7 | + x::vecT |
| 8 | + b::vecT |
| 9 | + r::vecT |
| 10 | + c::vecT |
| 11 | + u::vecT |
| 12 | + reltol::numT |
| 13 | + residual::numT |
| 14 | + prev_residual::numT |
| 15 | + maxiter::Int |
| 16 | + mv_products::Int |
| 17 | +end |
4 | 18 |
|
5 |
| -function cg!(x, A, b; |
6 |
| - tol = sqrt(eps(real(eltype(b)))), |
7 |
| - maxiter::Integer = min(20, size(A, 1)), |
8 |
| - log::Bool = false, |
9 |
| - Pl = Identity(), |
10 |
| - kwargs... |
11 |
| -) |
12 |
| - history = ConvergenceHistory(partial = !log) |
13 |
| - history[:tol] = tol |
14 |
| - log && reserve!(history, :resnorm, maxiter + 1) |
15 |
| - cg_method!(history, x, A, b, Pl; tol = tol, log = log, maxiter = maxiter, kwargs...) |
16 |
| - log && shrink!(history) |
17 |
| - log ? (x, history) : x |
| 19 | +type PCGIterable{precT, matT, vecT <: AbstractVector, numT <: Real, paramT <: Number} |
| 20 | + Pl::precT |
| 21 | + A::matT |
| 22 | + x::vecT |
| 23 | + b::vecT |
| 24 | + r::vecT |
| 25 | + c::vecT |
| 26 | + u::vecT |
| 27 | + reltol::numT |
| 28 | + residual::numT |
| 29 | + ρ::paramT |
| 30 | + maxiter::Int |
| 31 | + mv_products::Int |
18 | 32 | end
|
19 | 33 |
|
20 |
| -function cg_method!(history::ConvergenceHistory, x, A, b, Pl; |
21 |
| - tol = sqrt(eps(real(eltype(b)))), |
22 |
| - maxiter::Integer = min(20, size(A, 1)), |
23 |
| - verbose::Bool = false, |
24 |
| - log = false |
25 |
| -) |
26 |
| - # Preconditioned CG |
27 |
| - T = eltype(b) |
28 |
| - n = size(A, 1) |
| 34 | +@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.reltol |
29 | 35 |
|
30 |
| - # Initial residual vector |
31 |
| - r = copy(b) |
32 |
| - c = A * x |
33 |
| - @blas! r -= one(T) * c |
34 |
| - u = zeros(T, n) |
35 |
| - ρ = one(T) |
| 36 | +@inline start(it::Union{CGIterable, PCGIterable}) = 0 |
36 | 37 |
|
37 |
| - if log |
38 |
| - history.mvps += 1 |
39 |
| - end |
40 |
| - |
41 |
| - iter = 1 |
| 38 | +@inline done(it::Union{CGIterable, PCGIterable}, iteration::Int) = iteration ≥ it.maxiter || converged(it) |
42 | 39 |
|
43 |
| - # Here you could save one inner product if norm(r) is used rather than norm(b) |
44 |
| - reltol = norm(b) * tol |
45 |
| - last_residual = zero(T) |
46 | 40 |
|
47 |
| - while true |
| 41 | +############### |
| 42 | +# Ordinary CG # |
| 43 | +############### |
48 | 44 |
|
49 |
| - last_residual = norm(r) |
| 45 | +function next(it::CGIterable, iteration::Int) |
| 46 | + # u := r + βu (almost an axpy) |
| 47 | + β = it.residual^2 / it.prev_residual^2 |
| 48 | + @blas! it.u *= β |
| 49 | + @blas! it.u += one(eltype(it.b)) * it.r |
50 | 50 |
|
51 |
| - verbose && @printf("%3d\t%1.2e\n", iter, last_residual) |
| 51 | + # c = A * u |
| 52 | + A_mul_B!(it.c, it.A, it.u) |
| 53 | + α = it.residual^2 / dot(it.u, it.c) |
52 | 54 |
|
53 |
| - if last_residual ≤ reltol || iter > maxiter |
54 |
| - break |
55 |
| - end |
| 55 | + # Improve solution and residual |
| 56 | + @blas! it.x += α * it.u |
| 57 | + @blas! it.r -= α * it.c |
56 | 58 |
|
57 |
| - # Log progress |
58 |
| - if log |
59 |
| - nextiter!(history, mvps = 1) |
60 |
| - push!(history, :resnorm, last_residual) |
61 |
| - end |
| 59 | + it.prev_residual = it.residual |
| 60 | + it.residual = norm(it.r) |
62 | 61 |
|
63 |
| - # Preconditioner: c = Pl \ r |
64 |
| - solve!(c, Pl, r) |
| 62 | + # Return the residual at item and iteration number as state |
| 63 | + it.residual, iteration + 1 |
| 64 | +end |
65 | 65 |
|
66 |
| - ρ_prev = ρ |
67 |
| - ρ = dot(c, r) |
68 |
| - β = ρ / ρ_prev |
| 66 | +##################### |
| 67 | +# Preconditioned CG # |
| 68 | +##################### |
69 | 69 |
|
70 |
| - # u := c + βu (almost an axpy) |
71 |
| - @blas! u *= β |
72 |
| - @blas! u += one(T) * c |
| 70 | +function next(it::PCGIterable, iteration::Int) |
| 71 | + solve!(it.c, it.Pl, it.r) |
73 | 72 |
|
74 |
| - # c = A * u |
75 |
| - A_mul_B!(c, A, u) |
76 |
| - α = ρ / dot(u, c) |
| 73 | + ρ_prev = it.ρ |
| 74 | + it.ρ = dot(it.c, it.r) |
77 | 75 |
|
78 |
| - # Improve solution and residual |
79 |
| - @blas! x += α * u |
80 |
| - @blas! r -= α * c |
| 76 | + # u := c + βu (almost an axpy) |
| 77 | + β = it.ρ / ρ_prev |
| 78 | + @blas! it.u *= β |
| 79 | + @blas! it.u += one(eltype(it.b)) * it.c |
81 | 80 |
|
82 |
| - iter += 1 |
83 |
| - end |
| 81 | + # c = A * u |
| 82 | + A_mul_B!(it.c, it.A, it.u) |
| 83 | + α = it.ρ / dot(it.u, it.c) |
84 | 84 |
|
85 |
| - verbose && @printf("\n") |
86 |
| - log && setconv(history, last_residual < reltol) |
| 85 | + # Improve solution and residual |
| 86 | + @blas! it.x += α * it.u |
| 87 | + @blas! it.r -= α * it.c |
87 | 88 |
|
88 |
| - x |
| 89 | + it.residual = norm(it.r) |
| 90 | + |
| 91 | + # Return the residual at item and iteration number as state |
| 92 | + it.residual, iteration + 1 |
89 | 93 | end
|
90 | 94 |
|
91 |
| -function cg_method!(history::ConvergenceHistory, x, A, b, Pl::Identity; |
| 95 | +# Utility functions |
| 96 | + |
| 97 | +@inline cg_iterator(A, b, Pl = Identity(); kwargs...) = cg_iterator!(zerox(A, b), A, b, Pl; initially_zero = true, kwargs...) |
| 98 | + |
| 99 | +function cg_iterator!(x, A, b, Pl = Identity(); |
92 | 100 | tol = sqrt(eps(real(eltype(b)))),
|
93 |
| - maxiter::Integer = min(20, size(A, 1)), |
94 |
| - verbose::Bool = false, |
95 |
| - log = false |
| 101 | + maxiter = min(20, length(b)), |
| 102 | + initially_zero::Bool = false |
96 | 103 | )
|
97 |
| - # Unpreconditioned CG |
98 |
| - T = eltype(b) |
99 |
| - n = size(A, 1) |
100 |
| - |
101 |
| - # Initial residual vector |
| 104 | + u = zeros(x) |
102 | 105 | r = copy(b)
|
103 |
| - c = A * x |
104 |
| - @blas! r -= one(T) * c |
105 |
| - u = zeros(T, n) |
106 |
| - ρ = one(T) |
107 | 106 |
|
108 |
| - if log |
109 |
| - history.mvps += 1 |
| 107 | + # Compute r with an MV-product or not. |
| 108 | + if initially_zero |
| 109 | + mv_products = 0 |
| 110 | + c = similar(x) |
| 111 | + residual = norm(b) |
| 112 | + reltol = residual * tol # Save one dot product |
| 113 | + else |
| 114 | + mv_products = 1 |
| 115 | + c = A * x |
| 116 | + @blas! r -= one(eltype(x)) * c |
| 117 | + residual = norm(r) |
| 118 | + reltol = norm(b) * tol |
110 | 119 | end
|
111 | 120 |
|
112 |
| - iter = 1 |
113 |
| - |
114 |
| - reltol = norm(b) * tol |
115 |
| - last_residual = zero(T) |
116 |
| - |
117 |
| - while true |
| 121 | + # Stopping criterion |
| 122 | + ρ = one(residual) |
| 123 | + |
| 124 | + # Return the iterable |
| 125 | + if isa(Pl, Identity) |
| 126 | + return CGIterable(A, x, b, |
| 127 | + r, c, u, |
| 128 | + reltol, residual, ρ, |
| 129 | + maxiter, mv_products |
| 130 | + ) |
| 131 | + else |
| 132 | + return PCGIterable(Pl, A, x, b, |
| 133 | + r, c, u, |
| 134 | + reltol, residual, ρ, |
| 135 | + maxiter, mv_products |
| 136 | + ) |
| 137 | + end |
| 138 | +end |
118 | 139 |
|
119 |
| - ρ_prev = ρ |
120 |
| - ρ = dot(r, r) |
121 |
| - β = ρ / ρ_prev |
| 140 | +cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...) |
122 | 141 |
|
123 |
| - last_residual = sqrt(ρ) |
| 142 | +function cg!(x, A, b; |
| 143 | + tol = sqrt(eps(real(eltype(b)))), |
| 144 | + maxiter::Integer = min(20, size(A, 1)), |
| 145 | + plot = false, |
| 146 | + log::Bool = false, |
| 147 | + verbose::Bool = false, |
| 148 | + Pl = Identity(), |
| 149 | + kwargs... |
| 150 | +) |
| 151 | + (plot & !log) && error("Can't plot when log keyword is false") |
| 152 | + history = ConvergenceHistory(partial = !log) |
| 153 | + history[:tol] = tol |
| 154 | + log && reserve!(history, :resnorm, maxiter + 1) |
124 | 155 |
|
125 |
| - # Log progress |
| 156 | + # Actually perform CG |
| 157 | + iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs...) |
| 158 | + if log |
| 159 | + history.mvps = iterable.mv_products |
| 160 | + end |
| 161 | + for (iteration, item) = enumerate(iterable) |
126 | 162 | if log
|
127 | 163 | nextiter!(history, mvps = 1)
|
128 |
| - push!(history, :resnorm, last_residual) |
| 164 | + push!(history, :resnorm, iterable.residual) |
129 | 165 | end
|
130 |
| - |
131 |
| - verbose && @printf("%3d\t%1.2e\n", iter, last_residual) |
132 |
| - |
133 |
| - # Stopping condition |
134 |
| - if last_residual ≤ reltol || iter > maxiter |
135 |
| - break |
136 |
| - end |
137 |
| - |
138 |
| - # u := r + βu (almost an axpy) |
139 |
| - @blas! u *= β |
140 |
| - @blas! u += one(T) * r |
141 |
| - |
142 |
| - # c = A * u |
143 |
| - A_mul_B!(c, A, u) |
144 |
| - α = ρ / dot(u, c) |
145 |
| - |
146 |
| - # Improve solution and residual |
147 |
| - @blas! x += α * u |
148 |
| - @blas! r -= α * c |
149 |
| - |
150 |
| - iter += 1 |
| 166 | + verbose && @printf("%3d\t%1.2e\n", iteration, iterable.residual) |
151 | 167 | end
|
152 | 168 |
|
153 |
| - verbose && @printf("\n") |
154 |
| - log && setconv(history, last_residual < reltol) |
| 169 | + verbose && println() |
| 170 | + log && setconv(history, converged(iterable)) |
| 171 | + log && shrink!(history) |
| 172 | + plot && showplot(history) |
155 | 173 |
|
156 |
| - x |
| 174 | + log ? (iterable.x, history) : iterable.x |
157 | 175 | end
|
158 | 176 |
|
159 |
| - |
160 | 177 | #################
|
161 | 178 | # Documentation #
|
162 | 179 | #################
|
|
0 commit comments