Skip to content

Commit d48a419

Browse files
haampie authored and andreasnoack committed
Implement GMRES as an iterator (#143)
1 parent 8da9184 commit d48a419

File tree

2 files changed

+163
-129
lines changed

2 files changed

+163
-129
lines changed

benchmark/benchmark-linear-systems.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,10 @@ function gmres(; n = 100_000, tol = 1e-5, restart::Int = 15, maxiter::Int = 210)
5959

6060
println("Matrix of size ", n, " with ~", nnz(A) / n, " nonzeros per row")
6161
println("Tolerance = ", tol, "; restart = ", restart, "; max #iterations = ", maxiter)
62-
63-
impr = @benchmark IterativeSolvers.improved_gmres($A, $b, tol = $tol, restart = $restart, maxiter = $maxiter, log = false)
64-
old = @benchmark IterativeSolvers.gmres($A, $b, tol = $tol, restart = $restart, maxiter = $maxiter, log = false)
6562

66-
impr, old
63+
IterativeSolvers.gmres(A, b, tol = tol, restart = restart, maxiter = maxiter, verbose = true)
64+
65+
@benchmark IterativeSolvers.gmres($A, $b, tol = $tol, restart = $restart, maxiter = $maxiter)
6766
end
6867

6968
function bicgstabl()

src/gmres.jl

Lines changed: 160 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -1,143 +1,169 @@
1+
import Base: start, next, done
2+
13
export gmres, gmres!
24

3-
gmres(A, b; kwargs...) = gmres!(zeros(b), A, b; kwargs...)
5+
type ArnoldiDecomp{T, matT}
6+
A::matT
7+
V::Matrix{T} # Orthonormal basis vectors
8+
H::Matrix{T} # Hessenberg matrix
9+
end
410

5-
function gmres!(x, A, b;
6-
Pl = Identity(),
7-
Pr = Identity(),
8-
tol = sqrt(eps(real(eltype(b)))),
9-
restart::Int = min(20, length(b)),
10-
maxiter::Int = restart,
11-
log::Bool = false,
12-
kwargs...
11+
ArnoldiDecomp{matT}(A::matT, order::Int, T::Type) = ArnoldiDecomp{T, matT}(
12+
A,
13+
zeros(T, size(A, 1), order + 1),
14+
zeros(T, order + 1, order)
1315
)
14-
history = ConvergenceHistory(partial = !log, restart = restart)
15-
history[:tol] = tol
16-
log && reserve!(history, :resnorm, maxiter)
17-
gmres_method!(history, x, A, b; Pl = Pl, Pr = Pr, tol = tol, maxiter = maxiter, restart = restart, log = log, kwargs...)
18-
log && shrink!(history)
19-
log ? (x, history) : x
16+
17+
type Residual{numT, resT}
18+
current::resT # Current, absolute, preconditioned residual
19+
accumulator::resT # Used to compute the residual on the go
20+
nullvec::Vector{numT} # Vector in the null space of H to compute residuals
21+
β::resT # the initial residual
2022
end
2123

22-
function gmres_method!(history::ConvergenceHistory, x, A, b;
23-
Pl = Identity(),
24-
Pr = Identity(),
25-
tol = sqrt(eps(real(eltype(b)))),
26-
restart::Int = min(20, length(b)),
27-
outer::Int = 1,
28-
maxiter::Int = restart,
29-
verbose::Bool = false,
30-
log = false
24+
Residual(order, T::Type) = Residual{T, real(T)}(
25+
one(real(T)),
26+
one(real(T)),
27+
ones(T, order + 1),
28+
one(real(T))
3129
)
32-
T = eltype(b)
3330

34-
# Approximate solution
35-
arnoldi = ArnoldiDecomp(A, restart, T)
36-
residual = Residual(restart, T)
31+
type GMRESIterable{preclT, precrT, vecT <: AbstractVector, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
32+
Pl::preclT
33+
Pr::precrT
34+
x::vecT
35+
b::vecT
36+
Ax::vecT # Some room to work in.
37+
38+
arnoldi::arnoldiT
39+
residual::residualT
40+
41+
mv_products::Int
42+
restart::Int
43+
k::Int
44+
maxiter::Int
45+
reltol::resT
46+
β::resT
47+
end
3748

38-
# Workspace vector to reduce the # allocs.
39-
reserved_vec = similar(b)
40-
β = residual.current = init!(arnoldi, x, b, Pl, reserved_vec)
41-
init_residual!(residual, β)
49+
converged(g::GMRESIterable) = g.residual.current ≤ g.reltol
4250

43-
# Log the first mvp for computing the initial residual
44-
if log
45-
history.mvps += 1
46-
end
51+
start(::GMRESIterable) = 0
4752

48-
# Stopping criterion is based on |r0| / |rk|
49-
reltol = residual.current * tol
53+
done(g::GMRESIterable, iteration::Int) = iteration ≥ g.maxiter || converged(g)
5054

51-
# Total iterations (not reset after restart)
52-
total_iter = 1
55+
function next(g::GMRESIterable, iteration::Int)
5356

54-
while total_iter ≤ maxiter
57+
# Arnoldi step: expand
58+
expand!(g.arnoldi, g.Pl, g.Pr, g.k, g.Ax)
59+
g.mv_products += 1
5560

56-
# We already have the initial residual
57-
if total_iter > 1
61+
# Orthogonalize V[:, k + 1] w.r.t. V[:, 1 : k]
62+
g.arnoldi.H[g.k + 1, g.k] = orthogonalize_and_normalize!(
63+
view(g.arnoldi.V, :, 1 : g.k),
64+
view(g.arnoldi.V, :, g.k + 1),
65+
view(g.arnoldi.H, 1 : g.k, g.k)
66+
)
5867

59-
# Set the first basis vector
60-
β = init!(arnoldi, x, b, Pl, reserved_vec)
68+
# Implicitly computes the residual
69+
update_residual!(g.residual, g.arnoldi, g.k)
6170

62-
# And initialize the residual
63-
init_residual!(residual, β)
64-
65-
if log
66-
history.mvps += 1
67-
end
68-
end
71+
g.k += 1
6972

70-
# Inner iterations k = 1, ..., restart
71-
k = 1
72-
73-
while residual.current > reltol && k ≤ restart && total_iter ≤ maxiter
74-
75-
# Arnoldi step: expand
76-
expand!(arnoldi, Pl, Pr, k)
77-
78-
# Orthogonalize V[:, k + 1] w.r.t. V[:, 1 : k]
79-
arnoldi.H[k + 1, k] = orthogonalize_and_normalize!(
80-
view(arnoldi.V, :, 1 : k),
81-
view(arnoldi.V, :, k + 1),
82-
view(arnoldi.H, 1 : k, k)
83-
)
84-
85-
# Implicitly computes the residual
86-
update_residual!(residual, arnoldi, k)
87-
88-
if log
89-
nextiter!(history, mvps = 1)
90-
push!(history, :resnorm, residual.current)
91-
end
92-
93-
verbose && @printf("%3d\t%3d\t%1.2e\n", mod(total_iter, restart), k, residual.current)
94-
95-
k += 1
96-
total_iter += 1
97-
end
73+
# Computation of x only at the end of the iterations
74+
# and at restart.
75+
if g.k == g.restart + 1 || done(g, iteration + 1)
9876

9977
# Solve the projected problem Hy = β * e1 in the least-squares sense
100-
rhs = solve_least_squares!(arnoldi, β, k)
78+
rhs = solve_least_squares!(g.arnoldi, g.β, g.k)
10179

10280
# And improve the solution x ← x + Pr \ (V * y)
103-
update_solution!(x, view(rhs, 1 : k - 1), arnoldi, Pr, k)
104-
105-
# Converged?
106-
if residual.current ≤ reltol
107-
setconv(history, true)
108-
break
81+
update_solution!(g.x, view(rhs, 1 : g.k - 1), g.arnoldi, g.Pr, g.k, g.Ax)
82+
83+
g.k = 1
84+
85+
# Restart when not done.
86+
if !done(g, iteration)
87+
88+
# Set the first basis vector
89+
g.β = init!(g.arnoldi, g.x, g.b, g.Pl, g.Ax)
90+
91+
# And initialize the residual
92+
init_residual!(g.residual, g.β)
93+
94+
g.mv_products += 1
10995
end
11096
end
11197

112-
verbose && @printf("\n")
113-
x
98+
g.residual.current, iteration + 1
11499
end
115100

116-
type ArnoldiDecomp{T}
117-
A
118-
V::Matrix{T} # Orthonormal basis vectors
119-
H::Matrix{T} # Hessenberg matrix
120-
end
101+
gmres_iterable(A, b; kwargs...) = gmres_iterable!(zeros(b), A, b; initially_zero = true, kwargs...)
121102

122-
ArnoldiDecomp(A, order::Int, T::Type) = ArnoldiDecomp{T}(
123-
A,
124-
zeros(T, size(A, 1), order + 1),
125-
zeros(T, order + 1, order)
103+
function gmres_iterable!(x, A, b;
104+
Pl = Identity(),
105+
Pr = Identity(),
106+
tol = sqrt(eps(real(eltype(b)))),
107+
restart::Int = min(20, length(b)),
108+
maxiter::Int = restart,
109+
initially_zero = false
126110
)
111+
T = eltype(b)
127112

128-
type Residual{numT, resT}
129-
current::resT # Current relative residual
130-
accumulator::resT # Used to compute the residual on the go
131-
nullvec::Vector{numT} # Vector in the null space of H to compute residuals
132-
β::resT # the initial residual
113+
# Approximate solution
114+
arnoldi = ArnoldiDecomp(A, restart, T)
115+
residual = Residual(restart, T)
116+
mv_products = initially_zero == true ? 1 : 0
117+
118+
# Workspace vector to reduce the # allocs.
119+
Ax = similar(b)
120+
residual.current = init!(arnoldi, x, b, Pl, Ax, initially_zero = initially_zero)
121+
init_residual!(residual, residual.current)
122+
123+
reltol = tol * residual.current
124+
125+
GMRESIterable(Pl, Pr, x, b, Ax,
126+
arnoldi, residual,
127+
mv_products, restart, 1, maxiter, reltol, residual.current
128+
)
133129
end
134130

135-
Residual(order, T::Type) = Residual{T, real(T)}(
136-
one(real(T)),
137-
one(real(T)),
138-
ones(T, order + 1),
139-
one(real(T))
131+
gmres(A, b; kwargs...) = gmres!(zeros(b), A, b; initially_zero = true, kwargs...)
132+
133+
function gmres!(x, A, b;
134+
Pl = Identity(),
135+
Pr = Identity(),
136+
tol = sqrt(eps(real(eltype(b)))),
137+
restart::Int = min(20, length(b)),
138+
maxiter::Int = restart,
139+
log::Bool = false,
140+
initially_zero = false,
141+
verbose::Bool = false
140142
)
143+
history = ConvergenceHistory(partial = !log, restart = restart)
144+
history[:tol] = tol
145+
log && reserve!(history, :resnorm, maxiter)
146+
147+
iterable = gmres_iterable!(x, A, b; Pl = Pl, Pr = Pr, tol = tol, maxiter = maxiter, restart = restart, initially_zero = initially_zero)
148+
149+
verbose && @printf("=== gmres ===\n%4s\t%4s\t%7s\n","rest","iter","resnorm")
150+
151+
for (iteration, residual) = enumerate(iterable)
152+
if log
153+
nextiter!(history)
154+
history.mvps = iterable.mv_products
155+
push!(history, :resnorm, residual)
156+
end
157+
158+
verbose && @printf("%3d\t%3d\t%1.2e\n", 1 + div(iteration - 1, restart), 1 + mod(iteration - 1, restart), residual)
159+
end
160+
161+
verbose && println()
162+
setconv(history, converged(iterable))
163+
log && shrink!(history)
164+
165+
log ? (x, history) : x
166+
end
141167

142168
function update_residual!(r::Residual, arnoldi::ArnoldiDecomp, k::Int)
143169
# Cheaply computes the current residual
@@ -146,15 +172,20 @@ function update_residual!(r::Residual, arnoldi::ArnoldiDecomp, k::Int)
146172
r.current = r.β / r.accumulator
147173
end
148174

149-
function init!{T}(arnoldi::ArnoldiDecomp{T}, x, b, Pl, reserved_vec)
175+
function init!{T}(arnoldi::ArnoldiDecomp{T}, x, b, Pl, Ax; initially_zero::Bool = false)
150176
# Initialize the Krylov subspace with the initial residual vector
151177
# This basically does V[1] = Pl \ (b - A * x) and then normalize
152178

153179
first_col = view(arnoldi.V, :, 1)
154180

155181
copy!(first_col, b)
156-
A_mul_B!(reserved_vec, arnoldi.A, x)
157-
@blas! first_col -= one(T) * reserved_vec
182+
183+
# Potentially save one MV product
184+
if !initially_zero
185+
A_mul_B!(Ax, arnoldi.A, x)
186+
@blas! first_col -= one(T) * Ax
187+
end
188+
158189
A_ldiv_B!(Pl, first_col)
159190

160191
# Normalize
@@ -179,33 +210,37 @@ function solve_least_squares!{T}(arnoldi::ArnoldiDecomp{T}, β, k::Int)
179210
rhs
180211
end
181212

182-
function update_solution!{T}(x, y, arnoldi::ArnoldiDecomp{T}, Pr::Identity, k::Int)
213+
function update_solution!{T}(x, y, arnoldi::ArnoldiDecomp{T}, Pr::Identity, k::Int, Ax)
183214
# Update x ← x + V * y
184215

185216
# TODO: find the SugarBLAS alternative
186217
BLAS.gemv!('N', one(T), view(arnoldi.V, :, 1 : k - 1), y, one(T), x)
187218
end
188219

189-
function update_solution!{T}(x, y, arnoldi::ArnoldiDecomp{T}, Pr, k::Int)
190-
# Allocates a temporary while computing x ← x + Pr \ (V * y)
191-
tmp = view(arnoldi.V, :, 1 : k - 1) * y
192-
@blas! x += one(T) * (Pr \ tmp)
220+
function update_solution!{T}(x, y, arnoldi::ArnoldiDecomp{T}, Pr, k::Int, Ax)
221+
# Computing x ← x + Pr \ (V * y) and use Ax as a work space
222+
A_mul_B!(Ax, view(arnoldi.V, :, 1 : k - 1), y)
223+
A_ldiv_B!(Pr, Ax)
224+
@blas! x += one(T) * Ax
193225
end
194226

195-
function expand!(arnoldi::ArnoldiDecomp, Pl::Identity, Pr::Identity, k::Int)
227+
function expand!(arnoldi::ArnoldiDecomp, Pl::Identity, Pr::Identity, k::Int, Ax)
196228
# Simply expands by A * v without allocating
197229
A_mul_B!(view(arnoldi.V, :, k + 1), arnoldi.A, view(arnoldi.V, :, k))
198230
end
199231

200-
function expand!(arnoldi::ArnoldiDecomp, Pl, Pr::Identity, k::Int)
232+
function expand!(arnoldi::ArnoldiDecomp, Pl, Pr::Identity, k::Int, Ax)
201233
# Expands by Pl \ (A * v) without allocating
202-
A_mul_B!(view(arnoldi.V, :, k + 1), arnoldi.A, view(arnoldi.V, :, k))
203-
A_ldiv_B!(Pl, view(arnoldi.V, :, k + 1))
234+
nextV = view(arnoldi.V, :, k + 1)
235+
A_mul_B!(nextV, arnoldi.A, view(arnoldi.V, :, k))
236+
A_ldiv_B!(Pl, nextV)
204237
end
205238

206-
function expand!(arnoldi::ArnoldiDecomp, Pl, Pr, k::Int)
207-
# Expands by Pl \ (A * (Pr \ v)). Allocates one vector.
208-
A_ldiv_B!(view(arnoldi.V, :, k + 1), Pr, view(arnoldi.V, :, k))
209-
copy!(view(arnoldi.V, :, k + 1), arnoldi.A * view(arnoldi.V, :, k + 1))
210-
A_ldiv_B!(Pl, view(arnoldi.V, :, k + 1))
239+
function expand!(arnoldi::ArnoldiDecomp, Pl, Pr, k::Int, Ax)
240+
# Expands by Pl \ (A * (Pr \ v)). Avoids allocation by using Ax.
241+
nextV = view(arnoldi.V, :, k + 1)
242+
A_ldiv_B!(nextV, Pr, view(arnoldi.V, :, k))
243+
A_mul_B!(Ax, arnoldi.A, nextV)
244+
copy!(nextV, Ax)
245+
A_ldiv_B!(Pl, nextV)
211246
end

0 commit comments

Comments
 (0)