Skip to content

Commit 2ec9f11

Browse files
haampieandreasnoack
authored andcommitted
CG with iterators (#137)
* Add a new iterator variant for CG * PCG * Don't overspecify * Restore stopping condition & save one norm if x = 0 initially * Support for Julia 0.5
1 parent 3224fb2 commit 2ec9f11

File tree

2 files changed

+140
-122
lines changed

2 files changed

+140
-122
lines changed

src/cg.jl

Lines changed: 138 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,179 @@
1-
export cg, cg!
2-
3-
cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; kwargs...)
1+
import Base: start, next, done
2+
3+
export cg, cg!, CGIterable, PCGIterable, cg_iterator, cg_iterator!
4+
5+
type CGIterable{matT, vecT <: AbstractVector, numT <: Real}
6+
A::matT
7+
x::vecT
8+
b::vecT
9+
r::vecT
10+
c::vecT
11+
u::vecT
12+
reltol::numT
13+
residual::numT
14+
prev_residual::numT
15+
maxiter::Int
16+
mv_products::Int
17+
end
418

5-
function cg!(x, A, b;
6-
tol = sqrt(eps(real(eltype(b)))),
7-
maxiter::Integer = min(20, size(A, 1)),
8-
log::Bool = false,
9-
Pl = Identity(),
10-
kwargs...
11-
)
12-
history = ConvergenceHistory(partial = !log)
13-
history[:tol] = tol
14-
log && reserve!(history, :resnorm, maxiter + 1)
15-
cg_method!(history, x, A, b, Pl; tol = tol, log = log, maxiter = maxiter, kwargs...)
16-
log && shrink!(history)
17-
log ? (x, history) : x
19+
type PCGIterable{precT, matT, vecT <: AbstractVector, numT <: Real, paramT <: Number}
20+
Pl::precT
21+
A::matT
22+
x::vecT
23+
b::vecT
24+
r::vecT
25+
c::vecT
26+
u::vecT
27+
reltol::numT
28+
residual::numT
29+
ρ::paramT
30+
maxiter::Int
31+
mv_products::Int
1832
end
1933

20-
function cg_method!(history::ConvergenceHistory, x, A, b, Pl;
21-
tol = sqrt(eps(real(eltype(b)))),
22-
maxiter::Integer = min(20, size(A, 1)),
23-
verbose::Bool = false,
24-
log = false
25-
)
26-
# Preconditioned CG
27-
T = eltype(b)
28-
n = size(A, 1)
34+
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual it.reltol
2935

30-
# Initial residual vector
31-
r = copy(b)
32-
c = A * x
33-
@blas! r -= one(T) * c
34-
u = zeros(T, n)
35-
ρ = one(T)
36+
@inline start(it::Union{CGIterable, PCGIterable}) = 0
3637

37-
if log
38-
history.mvps += 1
39-
end
40-
41-
iter = 1
38+
@inline done(it::Union{CGIterable, PCGIterable}, iteration::Int) = iteration it.maxiter || converged(it)
4239

43-
# Here you could save one inner product if norm(r) is used rather than norm(b)
44-
reltol = norm(b) * tol
45-
last_residual = zero(T)
4640

47-
while true
41+
###############
42+
# Ordinary CG #
43+
###############
4844

49-
last_residual = norm(r)
45+
function next(it::CGIterable, iteration::Int)
46+
# u := r + βu (almost an axpy)
47+
β = it.residual^2 / it.prev_residual^2
48+
@blas! it.u *= β
49+
@blas! it.u += one(eltype(it.b)) * it.r
5050

51-
verbose && @printf("%3d\t%1.2e\n", iter, last_residual)
51+
# c = A * u
52+
A_mul_B!(it.c, it.A, it.u)
53+
α = it.residual^2 / dot(it.u, it.c)
5254

53-
if last_residual reltol || iter > maxiter
54-
break
55-
end
55+
# Improve solution and residual
56+
@blas! it.x += α * it.u
57+
@blas! it.r -= α * it.c
5658

57-
# Log progress
58-
if log
59-
nextiter!(history, mvps = 1)
60-
push!(history, :resnorm, last_residual)
61-
end
59+
it.prev_residual = it.residual
60+
it.residual = norm(it.r)
6261

63-
# Preconditioner: c = Pl \ r
64-
solve!(c, Pl, r)
62+
# Return the residual at item and iteration number as state
63+
it.residual, iteration + 1
64+
end
6565

66-
ρ_prev = ρ
67-
ρ = dot(c, r)
68-
β = ρ / ρ_prev
66+
#####################
67+
# Preconditioned CG #
68+
#####################
6969

70-
# u := c + βu (almost an axpy)
71-
@blas! u *= β
72-
@blas! u += one(T) * c
70+
function next(it::PCGIterable, iteration::Int)
71+
solve!(it.c, it.Pl, it.r)
7372

74-
# c = A * u
75-
A_mul_B!(c, A, u)
76-
α = ρ / dot(u, c)
73+
ρ_prev = it.ρ
74+
it.ρ = dot(it.c, it.r)
7775

78-
# Improve solution and residual
79-
@blas! x += α * u
80-
@blas! r -= α * c
76+
# u := c + βu (almost an axpy)
77+
β = it.ρ / ρ_prev
78+
@blas! it.u *= β
79+
@blas! it.u += one(eltype(it.b)) * it.c
8180

82-
iter += 1
83-
end
81+
# c = A * u
82+
A_mul_B!(it.c, it.A, it.u)
83+
α = it.ρ / dot(it.u, it.c)
8484

85-
verbose && @printf("\n")
86-
log && setconv(history, last_residual < reltol)
85+
# Improve solution and residual
86+
@blas! it.x += α * it.u
87+
@blas! it.r -= α * it.c
8788

88-
x
89+
it.residual = norm(it.r)
90+
91+
# Return the residual at item and iteration number as state
92+
it.residual, iteration + 1
8993
end
9094

91-
function cg_method!(history::ConvergenceHistory, x, A, b, Pl::Identity;
95+
# Utility functions
96+
97+
@inline cg_iterator(A, b, Pl = Identity(); kwargs...) = cg_iterator!(zerox(A, b), A, b, Pl; initially_zero = true, kwargs...)
98+
99+
function cg_iterator!(x, A, b, Pl = Identity();
92100
tol = sqrt(eps(real(eltype(b)))),
93-
maxiter::Integer = min(20, size(A, 1)),
94-
verbose::Bool = false,
95-
log = false
101+
maxiter = min(20, length(b)),
102+
initially_zero::Bool = false
96103
)
97-
# Unpreconditioned CG
98-
T = eltype(b)
99-
n = size(A, 1)
100-
101-
# Initial residual vector
104+
u = zeros(x)
102105
r = copy(b)
103-
c = A * x
104-
@blas! r -= one(T) * c
105-
u = zeros(T, n)
106-
ρ = one(T)
107106

108-
if log
109-
history.mvps += 1
107+
# Compute r with an MV-product or not.
108+
if initially_zero
109+
mv_products = 0
110+
c = similar(x)
111+
residual = norm(b)
112+
reltol = residual * tol # Save one dot product
113+
else
114+
mv_products = 1
115+
c = A * x
116+
@blas! r -= one(eltype(x)) * c
117+
residual = norm(r)
118+
reltol = norm(b) * tol
110119
end
111120

112-
iter = 1
113-
114-
reltol = norm(b) * tol
115-
last_residual = zero(T)
116-
117-
while true
121+
# Stopping criterion
122+
ρ = one(residual)
123+
124+
# Return the iterable
125+
if isa(Pl, Identity)
126+
return CGIterable(A, x, b,
127+
r, c, u,
128+
reltol, residual, ρ,
129+
maxiter, mv_products
130+
)
131+
else
132+
return PCGIterable(Pl, A, x, b,
133+
r, c, u,
134+
reltol, residual, ρ,
135+
maxiter, mv_products
136+
)
137+
end
138+
end
118139

119-
ρ_prev = ρ
120-
ρ = dot(r, r)
121-
β = ρ / ρ_prev
140+
cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
122141

123-
last_residual = sqrt(ρ)
142+
function cg!(x, A, b;
143+
tol = sqrt(eps(real(eltype(b)))),
144+
maxiter::Integer = min(20, size(A, 1)),
145+
plot = false,
146+
log::Bool = false,
147+
verbose::Bool = false,
148+
Pl = Identity(),
149+
kwargs...
150+
)
151+
(plot & !log) && error("Can't plot when log keyword is false")
152+
history = ConvergenceHistory(partial = !log)
153+
history[:tol] = tol
154+
log && reserve!(history, :resnorm, maxiter + 1)
124155

125-
# Log progress
156+
# Actually perform CG
157+
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs...)
158+
if log
159+
history.mvps = iterable.mv_products
160+
end
161+
for (iteration, item) = enumerate(iterable)
126162
if log
127163
nextiter!(history, mvps = 1)
128-
push!(history, :resnorm, last_residual)
164+
push!(history, :resnorm, iterable.residual)
129165
end
130-
131-
verbose && @printf("%3d\t%1.2e\n", iter, last_residual)
132-
133-
# Stopping condition
134-
if last_residual reltol || iter > maxiter
135-
break
136-
end
137-
138-
# u := r + βu (almost an axpy)
139-
@blas! u *= β
140-
@blas! u += one(T) * r
141-
142-
# c = A * u
143-
A_mul_B!(c, A, u)
144-
α = ρ / dot(u, c)
145-
146-
# Improve solution and residual
147-
@blas! x += α * u
148-
@blas! r -= α * c
149-
150-
iter += 1
166+
verbose && @printf("%3d\t%1.2e\n", iteration, iterable.residual)
151167
end
152168

153-
verbose && @printf("\n")
154-
log && setconv(history, last_residual < reltol)
169+
verbose && println()
170+
log && setconv(history, converged(iterable))
171+
log && shrink!(history)
172+
plot && showplot(history)
155173

156-
x
174+
log ? (iterable.x, history) : iterable.x
157175
end
158176

159-
160177
#################
161178
# Documentation #
162179
#################

test/cg.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ context("Sparse Laplacian") do
4343
L = tril(A)
4444
D = diag(A)
4545
U = triu(A)
46+
4647
JAC(x) = D .\ x
4748
SGS(x) = L \ (D .* (U \ x))
4849

@@ -72,7 +73,7 @@ context("Sparse Laplacian") do
7273
context("function with specified starting guess") do
7374
tol = 1e-4
7475
x0 = randn(size(rhs))
75-
xCG, hCG = cg!(copy(x0), Af, rhs; Pl=identity, tol=tol, maxiter=100, log=true)
76+
xCG, hCG = cg!(copy(x0), Af, rhs; tol=tol, maxiter=100, log=true)
7677
xJAC, hJAC = cg!(copy(x0), Af, rhs; Pl=JAC, tol=tol, maxiter=100, log=true)
7778
xSGS, hSGS = cg!(copy(x0), Af, rhs; Pl=SGS, tol=tol, maxiter=100, log=true)
7879
@fact norm(A * xCG - rhs) --> less_than_or_equal(tol)

0 commit comments

Comments
 (0)