Skip to content

Commit 152f89f

Browse files
committed
abstol, reltol for CG
1 parent a78f25e commit 152f89f

File tree

2 files changed

+90
-42
lines changed

2 files changed

+90
-42
lines changed

src/cg.jl

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mutable struct CGIterable{matT, solT, vecT, numT <: Real}
88
r::vecT
99
c::vecT
1010
u::vecT
11-
reltol::numT
11+
tol::numT
1212
residual::numT
1313
prev_residual::numT
1414
maxiter::Int
@@ -22,14 +22,14 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb
2222
r::vecT
2323
c::vecT
2424
u::vecT
25-
reltol::numT
25+
tol::numT
2626
residual::numT
2727
ρ::paramT
2828
maxiter::Int
2929
mv_products::Int
3030
end
3131

32-
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.reltol
32+
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol
3333

3434
@inline start(it::Union{CGIterable, PCGIterable}) = 0
3535

@@ -41,7 +41,10 @@ end
4141
###############
4242

4343
function iterate(it::CGIterable, iteration::Int=start(it))
44-
if done(it, iteration) return nothing end
44+
# Check for termination first
45+
if done(it, iteration)
46+
return nothing
47+
end
4548

4649
# u := r + βu (almost an axpy)
4750
β = it.residual^2 / it.prev_residual^2
@@ -72,6 +75,7 @@ function iterate(it::PCGIterable, iteration::Int=start(it))
7275
return nothing
7376
end
7477

78+
# Apply left preconditioner
7579
ldiv!(it.c, it.Pl, it.r)
7680

7781
ρ_prev = it.ρ
@@ -114,40 +118,48 @@ struct CGStateVariables{T,Tx<:AbstractArray{T}}
114118
end
115119

116120
function cg_iterator!(x, A, b, Pl = Identity();
117-
tol = sqrt(eps(real(eltype(b)))),
118-
maxiter::Int = size(A, 2),
119-
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
120-
initially_zero::Bool = false
121-
)
121+
abstol::Real = zero(real(eltype(b))),
122+
reltol::Real = sqrt(eps(real(eltype(b)))),
123+
tol = nothing, # TODO: Deprecations introduced in v0.8
124+
maxiter::Int = size(A, 2),
125+
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
126+
initially_zero::Bool = false)
122127
u = statevars.u
123128
r = statevars.r
124129
c = statevars.c
125130
u .= zero(eltype(x))
126131
copyto!(r, b)
127132

133+
# TODO: Deprecations introduced in v0.8
134+
if tol !== nothing
135+
Base.depwarn("The keyword argument `tol` is deprecated, use `reltol` instead.", :cg_iterator!)
136+
reltol = tol
137+
end
138+
128139
# Compute r with an MV-product or not.
129140
if initially_zero
130141
mv_products = 0
131-
c = similar(x)
132-
residual = norm(b)
133-
reltol = residual * tol # Save one dot product
134142
else
135143
mv_products = 1
136144
mul!(c, A, x)
137145
r .-= c
138-
residual = norm(r)
139-
reltol = norm(b) * tol
140146
end
147+
residual = norm(r)
148+
# TODO: According to the docs, the code below should use the initial residual
149+
# instead of the norm of the RHS `b` to set the relative tolerance.
150+
# See also https://github.com/JuliaMath/IterativeSolvers.jl/pull/244
151+
# tolerance = max(reltol * residual, abstol)
152+
tolerance = max(reltol * norm(b), abstol)
141153

142154
# Return the iterable
143155
if isa(Pl, Identity)
144156
return CGIterable(A, x, r, c, u,
145-
reltol, residual, one(residual),
157+
tolerance, residual, one(residual),
146158
maxiter, mv_products
147159
)
148160
else
149161
return PCGIterable(Pl, A, x, r, c, u,
150-
reltol, residual, one(eltype(x)),
162+
tolerance, residual, one(eltype(x)),
151163
maxiter, mv_products
152164
)
153165
end
@@ -199,20 +211,28 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
199211
- `:resnorm` => `::Vector`: residual norm at each iteration.
200212
"""
201213
function cg!(x, A, b;
202-
tol = sqrt(eps(real(eltype(b)))),
203-
maxiter::Int = size(A, 2),
204-
log::Bool = false,
205-
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
206-
verbose::Bool = false,
207-
Pl = Identity(),
208-
kwargs...
209-
)
214+
abstol::Real = zero(real(eltype(b))),
215+
reltol::Real = sqrt(eps(real(eltype(b)))),
216+
tol = nothing, # TODO: Deprecations introduced in v0.8
217+
maxiter::Int = size(A, 2),
218+
log::Bool = false,
219+
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
220+
verbose::Bool = false,
221+
Pl = Identity(),
222+
kwargs...)
210223
history = ConvergenceHistory(partial = !log)
211224
history[:tol] = tol
212225
log && reserve!(history, :resnorm, maxiter + 1)
213226

227+
# TODO: Deprecations introduced in v0.8
228+
if tol !== nothing
229+
Base.depwarn("The keyword argument `tol` is deprecated, use `reltol` instead.", :cg!)
230+
reltol = tol
231+
end
232+
214233
# Actually perform CG
215-
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs...)
234+
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter,
235+
statevars = statevars, kwargs...)
216236
if log
217237
history.mvps = iterable.mv_products
218238
end

test/cg.jl

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ Random.seed!(1234321)
2626
A = rand(T, n, n)
2727
A = A' * A + I
2828
b = rand(T, n)
29-
tol = eps(real(T))
29+
reltol = eps(real(T))
3030

31-
x,ch = cg(A, b; tol=tol, maxiter=2n, log=true)
31+
x,ch = cg(A, b; reltol=reltol, maxiter=2n, log=true)
3232
@test isa(ch, ConvergenceHistory)
33-
@test norm(A*x - b) / norm(b) ≤ tol
33+
@test norm(A*x - b) / norm(b) ≤ reltol
3434
@test ch.isconverged
3535

3636
# If you start from the exact solution, you should converge immediately
37-
x,ch = cg!(A \ b, A, b; tol=10tol, log=true)
37+
x,ch = cg!(A \ b, A, b; reltol=10*reltol, log=true)
3838
@test niters(ch) ≤ 1
3939
@test nprods(ch) ≤ 2
4040

@@ -56,29 +56,29 @@ end
5656

5757
rhs = randn(size(A, 2))
5858
rmul!(rhs, inv(norm(rhs)))
59-
tol = 1e-5
59+
reltol = 1e-5
6060

6161
@testset "SparseMatrixCSC{$T, $Ti}" for T in (Float64, Float32), Ti in (Int64, Int32)
62-
xCG = cg(A, rhs; tol=tol, maxiter=100)
63-
xJAC = cg(A, rhs; Pl=P, tol=tol, maxiter=100)
64-
@test norm(A * xCG - rhs) ≤ tol
65-
@test norm(A * xJAC - rhs) ≤ tol
62+
xCG = cg(A, rhs; reltol=reltol, maxiter=100)
63+
xJAC = cg(A, rhs; Pl=P, reltol=reltol, maxiter=100)
64+
@test norm(A * xCG - rhs) ≤ reltol
65+
@test norm(A * xJAC - rhs) ≤ reltol
6666
end
6767

6868
Af = LinearMap(A)
6969
@testset "Function" begin
70-
xCG = cg(Af, rhs; tol=tol, maxiter=100)
71-
xJAC = cg(Af, rhs; Pl=P, tol=tol, maxiter=100)
72-
@test norm(A * xCG - rhs) ≤ tol
73-
@test norm(A * xJAC - rhs) ≤ tol
70+
xCG = cg(Af, rhs; reltol=reltol, maxiter=100)
71+
xJAC = cg(Af, rhs; Pl=P, reltol=reltol, maxiter=100)
72+
@test norm(A * xCG - rhs) ≤ reltol
73+
@test norm(A * xJAC - rhs) ≤ reltol
7474
end
7575

7676
@testset "Function with specified starting guess" begin
7777
x0 = randn(size(rhs))
78-
xCG, hCG = cg!(copy(x0), Af, rhs; tol=tol, maxiter=100, log=true)
79-
xJAC, hJAC = cg!(copy(x0), Af, rhs; Pl=P, tol=tol, maxiter=100, log=true)
80-
@test norm(A * xCG - rhs) ≤ tol
81-
@test norm(A * xJAC - rhs) ≤ tol
78+
xCG, hCG = cg!(copy(x0), Af, rhs; reltol=reltol, maxiter=100, log=true)
79+
xJAC, hJAC = cg!(copy(x0), Af, rhs; Pl=P, reltol=reltol, maxiter=100, log=true)
80+
@test norm(A * xCG - rhs) ≤ reltol
81+
@test norm(A * xJAC - rhs) ≤ reltol
8282
@test niters(hJAC) == niters(hCG)
8383
end
8484
end
@@ -92,4 +92,32 @@ end
9292
@test hist.isconverged
9393
end
9494

95+
@testset "Termination criterion" begin
96+
for T in (Float32, Float64, ComplexF32, ComplexF64)
97+
A = T[ 2 -1 0
98+
-1 2 -1
99+
0 -1 2]
100+
n = size(A, 2)
101+
b = ones(T, n)
102+
x0 = A \ b
103+
perturbation = T[(-1)^i for i in 1:n]
104+
105+
# If the initial residual is small and a small relative tolerance is used,
106+
# many iterations are necessary
107+
x = x0 + sqrt(eps(real(T))) * perturbation
108+
initial_residual = norm(A * x - b)
109+
x, ch = cg!(x, A, b, log=true)
110+
@test_broken 2 ≤ niters(ch) ≤ n
111+
# This test is currently broken since `norm(b)` is used in `cg_iterator!`
112+
# instead of the initial `residual` as described in the documentation.
113+
114+
# If the initial residual is small and a large absolute tolerance is used,
115+
# no iterations are necessary
116+
x = x0 + 10*sqrt(eps(real(T))) * perturbation
117+
initial_residual = norm(A * x - b)
118+
x, ch = cg!(x, A, b, abstol=2*initial_residual, reltol=zero(real(T)), log=true)
119+
@test niters(ch) == 0
120+
end
121+
end
122+
95123
end

0 commit comments

Comments
 (0)