Skip to content

Commit c0a9e44

Browse files
authored
Merge pull request #280 from ranocha/hr/abstol_reltol
Absolute and relative tolerance for linear solvers
2 parents dcba3f5 + 0b9e47b commit c0a9e44

File tree

17 files changed

+589
-228
lines changed

17 files changed

+589
-228
lines changed

src/bicgstabl.jl

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mutable struct BiCGStabIterable{precT, matT, solT, vecT <: AbstractVector, small
1313

1414
max_mv_products::Int
1515
mv_products::Int
16-
reltol::realT
16+
tol::realT
1717
residual::realT
1818

1919
Pl::precT
@@ -25,11 +25,12 @@ mutable struct BiCGStabIterable{precT, matT, solT, vecT <: AbstractVector, small
2525
end
2626

2727
function bicgstabl_iterator!(x, A, b, l::Int = 2;
28-
Pl = Identity(),
29-
max_mv_products = size(A, 2),
30-
initial_zero = false,
31-
tol = sqrt(eps(real(eltype(b))))
32-
)
28+
Pl = Identity(),
29+
max_mv_products = size(A, 2),
30+
abstol::Real = zero(real(eltype(b))),
31+
reltol::Real = sqrt(eps(real(eltype(b)))),
32+
tol = nothing, # TODO: Deprecations introduced in v0.8
33+
initial_zero = false)
3334
T = eltype(x)
3435
n = size(A, 1)
3536
mv_products = 0
@@ -41,6 +42,12 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
4142

4243
residual = view(rs, :, 1)
4344

45+
# TODO: Deprecations introduced in v0.8
46+
if tol !== nothing
47+
Base.depwarn("The keyword argument `tol` is deprecated, use `reltol` instead.", :bicgstabl_iterator!)
48+
reltol = tol
49+
end
50+
4451
# Compute the initial residual rs[:, 1] = b - A * x
4552
# Avoid computing A * 0.
4653
if initial_zero
@@ -62,17 +69,17 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
6269
# For the least-squares problem
6370
M = zeros(T, l + 1, l + 1)
6471

65-
# Stopping condition based on relative tolerance.
66-
reltol = nrm * tol
72+
# Stopping condition based on absolute and relative tolerance.
73+
tolerance = max(reltol * nrm, abstol)
6774

6875
BiCGStabIterable(A, l, x, r_shadow, rs, us,
69-
max_mv_products, mv_products, reltol, nrm,
76+
max_mv_products, mv_products, tolerance, nrm,
7077
Pl,
7178
γ, ω, σ, M
7279
)
7380
end
7481

75-
@inline converged(it::BiCGStabIterable) = it.residual it.reltol
82+
@inline converged(it::BiCGStabIterable) = it.residual it.tol
7683
@inline start(::BiCGStabIterable) = 0
7784
@inline done(it::BiCGStabIterable, iteration::Int) = it.mv_products it.max_mv_products || converged(it)
7885

@@ -157,10 +164,16 @@ bicgstabl(A, b, l = 2; kwargs...) = bicgstabl!(zerox(A, b), A, b, l; initial_zer
157164
- `max_mv_products::Int = size(A, 2)`: maximum number of matrix vector products.
158165
For BiCGStab(l) this is a less dubious term than "number of iterations";
159166
- `Pl = Identity()`: left preconditioner of the method;
160-
- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`.
161-
Note that (1) the true residual norm is never computed during the iterations,
162-
only an approximation; and (2) if a preconditioner is given, the stopping condition is based on the
163-
*preconditioned residual*.
167+
- `abstol::Real = zero(real(eltype(b)))`,
168+
`reltol::Real = sqrt(eps(real(eltype(b))))`: absolute and relative
169+
tolerance for the stopping condition
170+
`|r_k| / |r_0| ≤ max(reltol * resnorm, abstol)`, where `r_k = A * x_k - b`
171+
is the residual in the `k`th iteration;
172+
!!! note
173+
1. The true residual norm is never computed during the iterations,
174+
only an approximation;
175+
2. If a left preconditioner is given, the stopping condition is based on the
176+
*preconditioned residual*.
164177
165178
# Return values
166179
@@ -174,21 +187,31 @@ For BiCGStab(l) this is a less dubious term than "number of iterations";
174187
- `history`: convergence history.
175188
"""
176189
function bicgstabl!(x, A, b, l = 2;
177-
tol = sqrt(eps(real(eltype(b)))),
178-
max_mv_products::Int = size(A, 2),
179-
log::Bool = false,
180-
verbose::Bool = false,
181-
Pl = Identity(),
182-
kwargs...
183-
)
190+
abstol::Real = zero(real(eltype(b))),
191+
reltol::Real = sqrt(eps(real(eltype(b)))),
192+
tol = nothing, # TODO: Deprecations introduced in v0.8
193+
max_mv_products::Int = size(A, 2),
194+
log::Bool = false,
195+
verbose::Bool = false,
196+
Pl = Identity(),
197+
kwargs...)
184198
history = ConvergenceHistory(partial = !log)
185-
history[:tol] = tol
199+
history[:abstol] = abstol
200+
history[:reltol] = reltol
186201

187202
# This doesn't yet make sense: the number of iters is smaller.
188203
log && reserve!(history, :resnorm, max_mv_products)
189204

190-
# Actually perform CG
191-
iterable = bicgstabl_iterator!(x, A, b, l; Pl = Pl, tol = tol, max_mv_products = max_mv_products, kwargs...)
205+
# TODO: Deprecations introduced in v0.8
206+
if tol !== nothing
207+
Base.depwarn("The keyword argument `tol` is deprecated, use `reltol` instead.", :bicgstabl!)
208+
reltol = tol
209+
end
210+
211+
# Actually perform iterative solve
212+
iterable = bicgstabl_iterator!(x, A, b, l; Pl = Pl,
213+
abstol = abstol, reltol = reltol,
214+
max_mv_products = max_mv_products, kwargs...)
192215

193216
if log
194217
history.mvps = iterable.mv_products

src/cg.jl

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mutable struct CGIterable{matT, solT, vecT, numT <: Real}
88
r::vecT
99
c::vecT
1010
u::vecT
11-
reltol::numT
11+
tol::numT
1212
residual::numT
1313
prev_residual::numT
1414
maxiter::Int
@@ -22,14 +22,14 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb
2222
r::vecT
2323
c::vecT
2424
u::vecT
25-
reltol::numT
25+
tol::numT
2626
residual::numT
2727
ρ::paramT
2828
maxiter::Int
2929
mv_products::Int
3030
end
3131

32-
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual it.reltol
32+
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual it.tol
3333

3434
@inline start(it::Union{CGIterable, PCGIterable}) = 0
3535

@@ -41,7 +41,10 @@ end
4141
###############
4242

4343
function iterate(it::CGIterable, iteration::Int=start(it))
44-
if done(it, iteration) return nothing end
44+
# Check for termination first
45+
if done(it, iteration)
46+
return nothing
47+
end
4548

4649
# u := r + βu (almost an axpy)
4750
β = it.residual^2 / it.prev_residual^2
@@ -72,6 +75,7 @@ function iterate(it::PCGIterable, iteration::Int=start(it))
7275
return nothing
7376
end
7477

78+
# Apply left preconditioner
7579
ldiv!(it.c, it.Pl, it.r)
7680

7781
ρ_prev = it.ρ
@@ -114,40 +118,44 @@ struct CGStateVariables{T,Tx<:AbstractArray{T}}
114118
end
115119

116120
function cg_iterator!(x, A, b, Pl = Identity();
117-
tol = sqrt(eps(real(eltype(b)))),
118-
maxiter::Int = size(A, 2),
119-
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
120-
initially_zero::Bool = false
121-
)
121+
abstol::Real = zero(real(eltype(b))),
122+
reltol::Real = sqrt(eps(real(eltype(b)))),
123+
tol = nothing, # TODO: Deprecations introduced in v0.8
124+
maxiter::Int = size(A, 2),
125+
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
126+
initially_zero::Bool = false)
122127
u = statevars.u
123128
r = statevars.r
124129
c = statevars.c
125130
u .= zero(eltype(x))
126131
copyto!(r, b)
127132

133+
# TODO: Deprecations introduced in v0.8
134+
if tol !== nothing
135+
Base.depwarn("The keyword argument `tol` is deprecated, use `reltol` instead.", :cg_iterator!)
136+
reltol = tol
137+
end
138+
128139
# Compute r with an MV-product or not.
129140
if initially_zero
130141
mv_products = 0
131-
c = similar(x)
132-
residual = norm(b)
133-
reltol = residual * tol # Save one dot product
134142
else
135143
mv_products = 1
136144
mul!(c, A, x)
137145
r .-= c
138-
residual = norm(r)
139-
reltol = norm(b) * tol
140146
end
147+
residual = norm(r)
148+
tolerance = max(reltol * residual, abstol)
141149

142150
# Return the iterable
143151
if isa(Pl, Identity)
144152
return CGIterable(A, x, r, c, u,
145-
reltol, residual, one(residual),
153+
tolerance, residual, one(residual),
146154
maxiter, mv_products
147155
)
148156
else
149157
return PCGIterable(Pl, A, x, r, c, u,
150-
reltol, residual, one(eltype(x)),
158+
tolerance, residual, one(eltype(x)),
151159
maxiter, mv_products
152160
)
153161
end
@@ -177,7 +185,11 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
177185
residual vector;
178186
- `Pl = Identity()`: left preconditioner of the method. Should be symmetric,
179187
positive-definite like `A`;
180-
- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`;
188+
- `abstol::Real = zero(real(eltype(b)))`,
189+
`reltol::Real = sqrt(eps(real(eltype(b))))`: absolute and relative
190+
tolerance for the stopping condition
191+
`|r_k| / |r_0| ≤ max(reltol * resnorm, abstol)`, where `r_k = A * x_k - b`
192+
is the residual in the `k`th iteration;
181193
- `maxiter::Int = size(A,2)`: maximum number of iterations;
182194
- `verbose::Bool = false`: print method information;
183195
- `log::Bool = false`: keep track of the residual norm in each iteration.
@@ -199,20 +211,29 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
199211
- `:resnom` => `::Vector`: residual norm at each iteration.
200212
"""
201213
function cg!(x, A, b;
202-
tol = sqrt(eps(real(eltype(b)))),
203-
maxiter::Int = size(A, 2),
204-
log::Bool = false,
205-
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
206-
verbose::Bool = false,
207-
Pl = Identity(),
208-
kwargs...
209-
)
214+
abstol::Real = zero(real(eltype(b))),
215+
reltol::Real = sqrt(eps(real(eltype(b)))),
216+
tol = nothing, # TODO: Deprecations introduced in v0.8
217+
maxiter::Int = size(A, 2),
218+
log::Bool = false,
219+
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
220+
verbose::Bool = false,
221+
Pl = Identity(),
222+
kwargs...)
210223
history = ConvergenceHistory(partial = !log)
211-
history[:tol] = tol
224+
history[:abstol] = abstol
225+
history[:reltol] = reltol
212226
log && reserve!(history, :resnorm, maxiter + 1)
213227

228+
# TODO: Deprecations introduced in v0.8
229+
if tol !== nothing
230+
Base.depwarn("The keyword argument `tol` is deprecated, use `reltol` instead.", :cg!)
231+
reltol = tol
232+
end
233+
214234
# Actually perform CG
215-
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs...)
235+
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter,
236+
statevars = statevars, kwargs...)
216237
if log
217238
history.mvps = iterable.mv_products
218239
end

0 commit comments

Comments
 (0)