Skip to content

Commit 8c76c51

Browse files
haampieandreasnoack
authored andcommitted
Fix CG without preconditioning (#133)
* Add Identity preconditioner type * Specialize for the case where no preconditioner is given * Revert simd loop to 2 BLAS-1 ops * Export Identity
1 parent 81e7872 commit 8c76c51

File tree

1 file changed

+105
-26
lines changed

1 file changed

+105
-26
lines changed

src/cg.jl

Lines changed: 105 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,80 +7,159 @@ function cg!(x, A, b;
77
maxiter::Integer = min(20, size(A, 1)),
88
plot = false,
99
log::Bool = false,
10+
Pl = Identity(),
1011
kwargs...
1112
)
1213
(plot & !log) && error("Can't plot when log keyword is false")
13-
K = KrylovSubspace(A, length(b), 1, Vector{Adivtype(A,b)}[])
14-
init!(K, x)
1514
history = ConvergenceHistory(partial = !log)
1615
history[:tol] = tol
17-
reserve!(history, :resnorm, maxiter)
18-
cg_method!(history, x, K, b; tol = tol, maxiter = maxiter, kwargs...)
19-
(plot || log) && shrink!(history)
16+
log && reserve!(history, :resnorm, maxiter + 1)
17+
cg_method!(history, x, A, b, Pl; tol = tol, log = log, maxiter = maxiter, kwargs...)
18+
log && shrink!(history)
2019
plot && showplot(history)
2120
log ? (x, history) : x
2221
end
2322

24-
function cg_method!(log::ConvergenceHistory, x, K, b;
25-
Pl = 1,
23+
function cg_method!(history::ConvergenceHistory, x, A, b, Pl;
2624
tol = sqrt(eps(real(eltype(b)))),
27-
maxiter::Integer = min(20, size(K.A, 1)),
28-
verbose::Bool = false
25+
maxiter::Integer = min(20, size(A, 1)),
26+
verbose::Bool = false,
27+
log = false
2928
)
29+
# Preconditioned CG
3030
T = eltype(b)
31-
n = size(K.A, 1)
31+
n = size(A, 1)
3232

3333
# Initial residual vector
3434
r = copy(b)
35-
@blas! r -= one(T) * nextvec(K)
36-
c = zeros(T, n)
35+
c = A * x
36+
@blas! r -= one(T) * c
3737
u = zeros(T, n)
3838
ρ = one(T)
39-
40-
iter = 0
4139

42-
last_residual = norm(r)
40+
if log
41+
history.mvps += 1
42+
end
43+
44+
iter = 1
4345

4446
# Here you could save one inner product if norm(r) is used rather than norm(b)
4547
reltol = norm(b) * tol
48+
last_residual = zero(T)
4649

47-
while last_residual > reltol && iter < maxiter
48-
nextiter!(log, mvps = 1)
50+
while true
51+
52+
last_residual = norm(r)
53+
54+
verbose && @printf("%3d\t%1.2e\n", iter, last_residual)
55+
56+
if last_residual reltol || iter > maxiter
57+
break
58+
end
59+
60+
# Log progress
61+
if log
62+
nextiter!(history, mvps = 1)
63+
push!(history, :resnorm, last_residual)
64+
end
4965

5066
# Preconditioner: c = Pl \ r
5167
solve!(c, Pl, r)
5268

5369
ρ_prev = ρ
5470
ρ = dot(c, r)
55-
β = -ρ / ρ_prev
71+
β = ρ / ρ_prev
5672

57-
# u := r - βu (almost an axpy)
58-
@blas! u *= -β
73+
# u := c + βu (almost an axpy)
74+
@blas! u *= β
5975
@blas! u += one(T) * c
6076

6177
# c = A * u
62-
append!(K, u)
63-
nextvec!(c, K)
78+
A_mul_B!(c, A, u)
6479
α = ρ / dot(u, c)
6580

6681
# Improve solution and residual
6782
@blas! x += α * u
6883
@blas! r -= α * c
6984

7085
iter += 1
71-
last_residual = norm(r)
86+
end
87+
88+
verbose && @printf("\n")
89+
log && setconv(history, last_residual < reltol)
90+
91+
x
92+
end
93+
94+
function cg_method!(history::ConvergenceHistory, x, A, b, Pl::Identity;
95+
tol = sqrt(eps(real(eltype(b)))),
96+
maxiter::Integer = min(20, size(A, 1)),
97+
verbose::Bool = false,
98+
log = false
99+
)
100+
# Unpreconditioned CG
101+
T = eltype(b)
102+
n = size(A, 1)
103+
104+
# Initial residual vector
105+
r = copy(b)
106+
c = A * x
107+
@blas! r -= one(T) * c
108+
u = zeros(T, n)
109+
ρ = one(T)
110+
111+
if log
112+
history.mvps += 1
113+
end
114+
115+
iter = 1
116+
117+
reltol = norm(b) * tol
118+
last_residual = zero(T)
119+
120+
while true
121+
122+
ρ_prev = ρ
123+
ρ = dot(r, r)
124+
β = ρ / ρ_prev
125+
126+
last_residual = sqrt(ρ)
72127

73128
# Log progress
74-
push!(log, :resnorm, last_residual)
129+
if log
130+
nextiter!(history, mvps = 1)
131+
push!(history, :resnorm, last_residual)
132+
end
133+
75134
verbose && @printf("%3d\t%1.2e\n", iter, last_residual)
135+
136+
# Stopping condition
137+
if last_residual reltol || iter > maxiter
138+
break
139+
end
140+
141+
# u := r + βu (almost an axpy)
142+
@blas! u *= β
143+
@blas! u += one(T) * r
144+
145+
# c = A * u
146+
A_mul_B!(c, A, u)
147+
α = ρ / dot(u, c)
148+
149+
# Improve solution and residual
150+
@blas! x += α * u
151+
@blas! r -= α * c
152+
153+
iter += 1
76154
end
77155

78156
verbose && @printf("\n")
79-
setconv(log, last_residual < reltol)
157+
log && setconv(history, last_residual < reltol)
80158

81159
x
82160
end
83161

162+
84163
#################
85164
# Documentation #
86165
#################
@@ -126,7 +205,7 @@ $arg
126205
127206
## Keywords
128207
129-
`Pl = 1`: left preconditioner of the method.
208+
`Pl = Identity()`: left preconditioner of the method.
130209
131210
`tol::Real = size(A,2)*eps()`: stopping tolerance.
132211

0 commit comments

Comments
 (0)