Skip to content

Commit ae944bd

Browse files
haampieandreasnoack
authored andcommitted
MINRES (#148)
* Add MINRES * Add tests and fix a bug for complex arithmetic * Add documentation * Add support for skew-Hermitian matrices
1 parent 1c6085f commit ae944bd

File tree

5 files changed

+359
-0
lines changed

5 files changed

+359
-0
lines changed

benchmark/benchmark-linear-systems.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,12 @@ function bicgstabl()
7676
b1, b2, b3, b4
7777
end
7878

79+
function minres(n = 100_000)
80+
A = SymTridiagonal(fill(2.1, n), fill(-1.0, n))
81+
x = ones(n)
82+
b = A * x
83+
84+
@benchmark IterativeSolvers.minres($A, $b, maxiter = 100)
85+
end
86+
7987
end

src/IterativeSolvers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include("hessenberg.jl")
2121
#Linear solvers
2222
include("stationary.jl")
2323
include("cg.jl")
24+
include("minres.jl")
2425
include("bicgstabl.jl")
2526
include("gmres.jl")
2627
include("chebyshev.jl")

src/minres.jl

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
export minres_iterable, minres
2+
3+
import Base.LinAlg: BLAS.axpy!, givensAlgorithm
4+
import Base: start, next, done
5+
6+
"""
7+
MINRES is full GMRES for Hermitian matrices, finding
8+
xₖ := x₀ + Vₖyₖ, where Vₖ is an orthonormal basis for
9+
the Krylov subspace K(A, b - Ax₀). Since
10+
|rₖ| = |b - Axₖ| = |r₀ - Vₖ₊₁Hₖyₖ| = | |r₀|e₁ - Hₖyₖ|, we solve
11+
Hₖyₖ = |r₀|e₁ in least-square sense. As the Hessenberg matrix
12+
is tridiagonal, its QR decomp Hₖ = QₖRₖ has Rₖ with only 3 diagonals,
13+
easily obtained with Givens rotations. The least-squares problem is
14+
solved as Hₖ'Hₖyₖ = Hₖ'|r₀|e₁ => yₖ = inv(Rₖ) Qₖ'|r₀|e₁, so
15+
xₖ := x₀ + [Vₖ inv(Rₖ)] [Qₖ'|r₀|e₁]. Now the main difference with GMRES
16+
is the placement of the brackets. MINRES computes Wₖ = Vₖ inv(Rₖ) via 3-term
17+
recurrences using only the last column of R, and computes last
18+
two terms of Qₖ'|r₀|e₁ as well.
19+
The active column of the Hessenberg matrix H is updated in place to form the
20+
active column of R. Note that for Hermitian matrices the H matrix is purely
21+
real, while for skew-Hermitian matrices its diagonal is purely imaginary.
22+
"""
23+
type MINRESIterable{matT, vecT <: DenseVector, smallVecT <: DenseVector, rotT <: Number, realT <: Real}
24+
A::matT
25+
skew_hermitian::Bool
26+
x::vecT
27+
28+
# Krylov basis vectors
29+
v_prev::vecT
30+
v_curr::vecT
31+
v_next::vecT
32+
33+
# W = R * inv(V) is computed using 3-term recurrence
34+
w_prev::vecT
35+
w_curr::vecT
36+
w_next::vecT
37+
38+
# Vector of size 4, holding the active column of the Hessenberg matrix
39+
# rhs is just two active values of the right-hand side.
40+
H::smallVecT
41+
rhs::smallVecT
42+
43+
# Some Givens rotations
44+
c_prev::rotT
45+
s_prev::rotT
46+
c_curr::rotT
47+
s_curr::rotT
48+
49+
# Bookkeeping
50+
mv_products::Int
51+
maxiter::Int
52+
tolerance::realT
53+
resnorm::realT
54+
end
55+
56+
minres_iterable(A, b; kwargs...) = minres_iterable!(zerox(A, b), A, b; initially_zero = true, kwargs...)
57+
58+
function minres_iterable!(x, A, b;
59+
initially_zero::Bool = false,
60+
skew_hermitian::Bool = false,
61+
tol = sqrt(eps(real(eltype(b)))),
62+
maxiter = size(A, 1)
63+
)
64+
T = eltype(b)
65+
HessenbergT = skew_hermitian ? T : real(T)
66+
67+
v_prev = similar(b)
68+
v_curr = copy(b)
69+
v_next = similar(b)
70+
w_prev = similar(b)
71+
w_curr = similar(b)
72+
w_next = similar(b)
73+
74+
mv_products = 0
75+
76+
# For nonzero x's, we must do an MV for the initial residual vec
77+
if !initially_zero
78+
# Use v_next to store Ax; v_next will soon be overwritten.
79+
A_mul_B!(v_next, A, x)
80+
axpy!(-one(T), v_next, v_curr)
81+
mv_products = 1
82+
end
83+
84+
resnorm = norm(v_curr)
85+
reltol = resnorm * tol
86+
87+
# Last active column of the Hessenberg matrix
88+
# and last two entries of the right-hand side
89+
H = zeros(HessenbergT, 4)
90+
rhs = [resnorm; zero(HessenbergT)]
91+
92+
# Normalize the first Krylov basis vector
93+
scale!(v_curr, inv(resnorm))
94+
95+
# Givens rotations
96+
c_prev, s_prev = one(T), zero(T)
97+
c_curr, s_curr = one(T), zero(T)
98+
99+
MINRESIterable(
100+
A, skew_hermitian, x,
101+
v_prev, v_curr, v_next,
102+
w_prev, w_curr, w_next,
103+
H, rhs,
104+
c_prev, s_prev, c_curr, s_curr,
105+
mv_products, maxiter, reltol, resnorm
106+
)
107+
end
108+
109+
converged(m::MINRESIterable) = m.resnorm m.tolerance
110+
111+
start(::MINRESIterable) = 1
112+
113+
done(m::MINRESIterable, iteration::Int) = iteration > m.maxiter || converged(m)
114+
115+
function next(m::MINRESIterable, iteration::Int)
116+
# v_next = A * v_curr - H[2] * v_prev
117+
A_mul_B!(m.v_next, m.A, m.v_curr)
118+
119+
iteration > 1 && axpy!(-m.H[2], m.v_prev, m.v_next)
120+
121+
# Orthogonalize w.r.t. v_curr
122+
proj = dot(m.v_curr, m.v_next)
123+
m.H[3] = m.skew_hermitian ? proj : real(proj)
124+
axpy!(-proj, m.v_curr, m.v_next)
125+
126+
# Normalize
127+
m.H[4] = norm(m.v_next)
128+
scale!(m.v_next, inv(m.H[4]))
129+
130+
# Rotation on H[1] and H[2]. Note that H[1] = 0 initially
131+
if iteration > 2
132+
m.H[1] = m.s_prev * m.H[2]
133+
m.H[2] = m.c_prev * m.H[2]
134+
end
135+
136+
# Rotation on H[2] and H[3]
137+
if iteration > 1
138+
tmp = -conj(m.s_curr) * m.H[2] + m.c_curr * m.H[3]
139+
m.H[2] = m.c_curr * m.H[2] + m.s_curr * m.H[3]
140+
m.H[3] = tmp
141+
end
142+
143+
# Next rotation
144+
c, s, m.H[3] = givensAlgorithm(m.H[3], m.H[4])
145+
146+
# Apply as well to the right-hand side
147+
m.rhs[2] = -conj(s) * m.rhs[1]
148+
m.rhs[1] = c * m.rhs[1]
149+
150+
# Update W = V * inv(R). Two axpy's can maybe be one MV.
151+
copy!(m.w_next, m.v_curr)
152+
iteration > 1 && axpy!(-m.H[2], m.w_curr, m.w_next)
153+
iteration > 2 && axpy!(-m.H[1], m.w_prev, m.w_next)
154+
scale!(m.w_next, inv(m.H[3]))
155+
156+
# Update solution x
157+
axpy!(m.rhs[1], m.w_next, m.x)
158+
159+
# Move on: next -> curr, curr -> prev
160+
m.v_prev, m.v_curr, m.v_next = m.v_curr, m.v_next, m.v_prev
161+
m.w_prev, m.w_curr, m.w_next = m.w_curr, m.w_next, m.w_prev
162+
m.c_prev, m.s_prev, m.c_curr, m.s_curr = m.c_curr, m.s_curr, c, s
163+
m.rhs[1] = m.rhs[2]
164+
165+
# Due to symmetry of the tri-diagonal matrix
166+
m.H[2] = m.skew_hermitian ? -m.H[4] : m.H[4]
167+
168+
# The approximate residual is cheaply available
169+
m.resnorm = abs(m.rhs[2])
170+
171+
m.resnorm, iteration + 1
172+
end
173+
174+
function minres!(x, A, b;
175+
skew_hermitian::Bool = false,
176+
verbose::Bool = false,
177+
log::Bool = false,
178+
tol = sqrt(eps(real(eltype(b)))),
179+
maxiter::Int = min(30, size(A, 1)),
180+
initially_zero::Bool = false
181+
)
182+
history = ConvergenceHistory(partial = !log)
183+
history[:tol] = tol
184+
log && reserve!(history, :resnorm, maxiter)
185+
186+
iterable = minres_iterable!(x, A, b;
187+
skew_hermitian = skew_hermitian,
188+
tol = tol,
189+
maxiter = maxiter,
190+
initially_zero = initially_zero
191+
)
192+
193+
if log
194+
history.mvps = iterable.mv_products
195+
end
196+
197+
for (iteration, resnorm) = enumerate(iterable)
198+
if log
199+
nextiter!(history, mvps = 1)
200+
push!(history, :resnorm, resnorm)
201+
end
202+
verbose && @printf("%3d\t%1.2e\n", iteration, resnorm)
203+
end
204+
205+
verbose && println()
206+
log && setconv(history, converged(iterable))
207+
log && shrink!(history)
208+
209+
log ? (iterable.x, history) : iterable.x
210+
end
211+
212+
minres(A, b; kwargs...) = minres!(zerox(A, b), A, b; initially_zero = true, kwargs...)
213+
214+
#################
215+
# Documentation #
216+
#################
217+
218+
let
219+
doc_call = "minres(A, b)"
220+
doc!_call = "minres!(x, A, b)"
221+
222+
doc_msg = "Using initial guess zeros(b)."
223+
doc!_msg = "Overwrites `x`."
224+
225+
doc_arg = ""
226+
doc!_arg = """`x`: initial guess, overwrite final approximation."""
227+
228+
doc_version = (doc_call, doc_msg, doc_arg)
229+
doc!_version = (doc!_call, doc!_msg, doc!_arg)
230+
231+
docstring = String[]
232+
233+
#Build docs
234+
for (call, msg, arg) in (doc_version, doc!_version) #Start
235+
push!(docstring,
236+
"""
237+
$call
238+
239+
Solve A*x = b for (skew-)Hermitian matrices A using MINRES. The method is mathematically
240+
equivalent to unrestarted GMRES, but exploits symmetry of A, resulting in short
241+
recurrences requiring only 6 vectors of storage. MINRES might be slightly less
242+
stable than full GMRES.
243+
244+
$msg
245+
246+
# Arguments
247+
248+
$arg
249+
250+
`A`: linear operator.
251+
252+
`b`: right hand side (vector).
253+
254+
## Keywords
255+
256+
`tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition
257+
`|r_k| / |r_0| ≤ tol`. Note that the residual is computed only approximately.
258+
259+
`maxiter::Int = min(30, size(A, 1))`: maximum number of iterations.
260+
261+
`verbose::Bool = false` output during the iterations
262+
263+
`log::Bool = false` enables logging, see **Output**.
264+
265+
# Output
266+
267+
**if `log` is `false`**
268+
269+
`x`: approximated solution.
270+
271+
**if `log` is `true`**
272+
273+
`x`: approximated solution.
274+
275+
`ch`: convergence history.
276+
"""
277+
)
278+
end
279+
280+
@doc docstring[1] -> minres
281+
@doc docstring[2] -> minres!
282+
end

test/minres.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using IterativeSolvers
2+
using FactCheck
3+
using Base.Test
4+
using LinearMaps
5+
6+
srand(123)
7+
8+
facts("MINRES") do
9+
10+
function hermitian_problem(T, n)
11+
B = rand(T, n, n) + n * eye(n)
12+
A = B + B'
13+
x = ones(T, n)
14+
b = B * x
15+
A, x, b
16+
end
17+
18+
function skew_hermitian_problem(T, n)
19+
B = rand(T, n, n) + n * eye(n)
20+
A = B - B'
21+
x = ones(T, n)
22+
b = A * x
23+
A, x, b
24+
end
25+
26+
for T in (Float32, Float64, Complex64, Complex128)
27+
n = 15
28+
29+
context("Hermitian Matrix{$T}") do
30+
A, x, b = hermitian_problem(T, n)
31+
tol = sqrt(eps(real(T)))
32+
33+
x_approx, hist = minres(A, b, maxiter = 10n, tol = tol, log = true)
34+
35+
@fact norm(b - A * x_approx) / norm(b) --> less_than_or_equal(tol)
36+
@fact hist.isconverged --> true
37+
end
38+
39+
context("Skew-Hermitian Matrix{$T}") do
40+
A, x, b = skew_hermitian_problem(T, n)
41+
tol = sqrt(eps(real(T)))
42+
x_approx, hist = minres(A, b, skew_hermitian = true, maxiter = 10n, tol = tol, log = true)
43+
44+
@fact norm(b - A * x_approx) / norm(b) --> less_than_or_equal(tol)
45+
@fact hist.isconverged --> true
46+
end
47+
48+
context("SparseMatrixCSC{$T}") do
49+
A = let
50+
B = sprand(n, n, 2 / n)
51+
B + B' + speye(n)
52+
end
53+
54+
x = ones(T, n)
55+
b = A * x
56+
tol = sqrt(eps(real(T)))
57+
58+
x_approx, hist = minres(A, b, maxiter = 10n, tol = tol, log = true)
59+
60+
@fact norm(b - A * x_approx) / norm(b) --> less_than_or_equal(tol)
61+
@fact hist.isconverged --> true
62+
end
63+
end
64+
65+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ include("cg.jl")
1616
#BiCGStab(l)
1717
include("bicgstabl.jl")
1818

19+
#MINRES
20+
include("minres.jl")
21+
1922
#GMRES
2023
include("gmres.jl")
2124

0 commit comments

Comments
 (0)