Skip to content

Commit eb031c0

Browse files
haampielopezm94
authored andcommitted
Reduce memory usage in CG (#130)
* Reduce memory usage in CG * Bring back Krylov subspace again * Fix deprecation warning
1 parent ac7bbee commit eb031c0

File tree

3 files changed

+110
-49
lines changed

3 files changed

+110
-49
lines changed

benchmark/benchmark-linear-systems.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
module LinearSystemsBench
2+
3+
import Base.A_ldiv_B!, Base.\
4+
5+
using BenchmarkTools
6+
using IterativeSolvers
7+
8+
# A DiagonalMatrix that doesn't check whether it is singular in the \ op.
9+
immutable DiagonalPreconditioner{T}
10+
diag::Vector{T}
11+
end
12+
13+
function A_ldiv_B!{T}(y::AbstractVector{T}, A::DiagonalPreconditioner{T}, b::AbstractVector{T})
14+
for i = 1 : length(b)
15+
@inbounds y[i] = A.diag[i] \ b[i]
16+
end
17+
y
18+
end
19+
20+
(\)(D::DiagonalPreconditioner, b::AbstractVector) = D.diag .\ b
21+
22+
function posdef(n)
23+
A = SymTridiagonal(fill(2.01, n), fill(-1.0, n))
24+
b = A * ones(n)
25+
return A, b
26+
end
27+
28+
function cg(; n = 1_000_000, tol = 1e-6, maxiter::Int = 200)
29+
A, b = posdef(n)
30+
P = DiagonalPreconditioner(collect(linspace(1.0, 2.0, n)))
31+
32+
println("Symmetric positive definite matrix of size ", n)
33+
println("Eigenvalues in interval [0.01, 4.01]")
34+
println("Tolerance = ", tol, "; max #iterations = ", maxiter)
35+
36+
# Dry run
37+
initial = rand(n)
38+
IterativeSolvers.cg!(copy(initial), A, b, Pl = P, maxiter = maxiter, tol = tol, log = false)
39+
40+
# Actual benchmark
41+
@benchmark IterativeSolvers.cg!(x0, $A, $b, Pl = $P, maxiter = $maxiter, tol = $tol, log = false) setup=(x0 = copy($initial))
42+
end
43+
44+
end

src/cg.jl

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,83 @@
11
export cg, cg!
22

3-
####################
4-
# API method calls #
5-
####################
6-
7-
cg(A, b; kwargs...) = cg!(zerox(A,b), A, b; kwargs...)
3+
cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; kwargs...)
84

95
function cg!(x, A, b;
10-
tol::Real=size(A,2)*eps(), maxiter::Integer=size(A,2),
11-
plot=false, log::Bool=false, kwargs...
12-
)
6+
tol = sqrt(eps(real(eltype(b)))),
7+
maxiter::Integer = min(20, size(A, 1)),
8+
plot = false,
9+
log::Bool = false,
10+
kwargs...
11+
)
1312
(plot & !log) && error("Can't plot when log keyword is false")
1413
K = KrylovSubspace(A, length(b), 1, Vector{Adivtype(A,b)}[])
1514
init!(K, x)
16-
history = ConvergenceHistory(partial=!log)
15+
history = ConvergenceHistory(partial = !log)
1716
history[:tol] = tol
18-
if all(el->el==0, b)
19-
fill!(x, zero(eltype(x)))
20-
return log ? (x, history) : x
21-
end
22-
reserve!(history,:resnorm, maxiter)
23-
cg_method!(history, x, K, b; tol=tol, maxiter=maxiter, kwargs...)
17+
reserve!(history, :resnorm, maxiter)
18+
cg_method!(history, x, K, b; tol = tol, maxiter = maxiter, kwargs...)
2419
(plot || log) && shrink!(history)
2520
plot && showplot(history)
2621
log ? (x, history) : x
2722
end
2823

29-
#########################
30-
# Method Implementation #
31-
#########################
32-
3324
function cg_method!(log::ConvergenceHistory, x, K, b;
34-
Pl=1,tol::Real=size(K.A,2)*eps(),maxiter::Integer=size(K.A,2), verbose::Bool=false
35-
)
36-
verbose && @printf("=== cg ===\n%4s\t%7s\n","iter","resnorm")
37-
tol = tol * norm(b)
38-
r = b - nextvec(K)
39-
q = zeros(r)
40-
z = solve(Pl,r)
41-
p = copy(z)
42-
γ = dot(r, z)
43-
for iter=1:maxiter
44-
nextiter!(log, mvps=1)
45-
append!(K, p)
46-
nextvec!(q, K)
47-
α = γ/dot(p, q)
48-
# α>=0 || throw(PosSemidefException("α=$α"))
49-
@blas! x += α*p
50-
@blas! r += -α*q
51-
resnorm = norm(r)
52-
push!(log,:resnorm,resnorm)
53-
verbose && @printf("%3d\t%1.2e\n",iter,resnorm)
54-
resnorm < tol && break
55-
solve!(z,Pl,r)
56-
oldγ = γ
57-
γ = dot(r, z)
58-
β = γ/oldγ
59-
@blas! p *= β
60-
@blas! p += z
25+
Pl = 1,
26+
tol = sqrt(eps(real(eltype(b)))),
27+
maxiter::Integer = min(20, size(K.A, 1)),
28+
verbose::Bool = false
29+
)
30+
T = eltype(b)
31+
n = size(K.A, 1)
32+
33+
# Initial residual vector
34+
r = copy(b)
35+
@blas! r -= one(T) * nextvec(K)
36+
c = zeros(T, n)
37+
u = zeros(T, n)
38+
ρ = one(T)
39+
40+
iter = 0
41+
42+
last_residual = norm(r)
43+
44+
# Here you could save one inner product if norm(r) is used rather than norm(b)
45+
reltol = norm(b) * tol
46+
47+
while last_residual > reltol && iter < maxiter
48+
nextiter!(log, mvps = 1)
49+
50+
# Preconditioner: c = Pl \ r
51+
solve!(c, Pl, r)
52+
53+
ρ_prev = ρ
54+
ρ = dot(c, r)
55+
β = -ρ / ρ_prev
56+
57+
# u := r - βu (almost an axpy)
58+
@blas! u *= -β
59+
@blas! u += one(T) * c
60+
61+
# c = A * u
62+
append!(K, u)
63+
nextvec!(c, K)
64+
α = ρ / dot(u, c)
65+
66+
# Improve solution and residual
67+
@blas! x += α * u
68+
@blas! r -= α * c
69+
70+
iter += 1
71+
last_residual = norm(r)
72+
73+
# Log progress
74+
push!(log, :resnorm, last_residual)
75+
verbose && @printf("%3d\t%1.2e\n", iter, last_residual)
6176
end
77+
6278
verbose && @printf("\n")
63-
setconv(log, 0<=norm(r)<tol)
79+
setconv(log, last_residual < reltol)
80+
6481
x
6582
end
6683

@@ -145,4 +162,4 @@ end
145162

146163
@doc docstring[1] -> cg
147164
@doc docstring[2] -> cg!
148-
end
165+
end

test/getDivGrad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ function spdiags(B,d,m,n)
3333
a[(round(Int, len[k])+1):round(Int, len[k+1]),:] = [i i+d[k] B[i+(m >= n)*d[k], k]]
3434
end
3535

36-
sparse(round(Int, a[:,1]), round(Int, a[:,2]), a[:,3], m, n)
36+
sparse(round.(Int, a[:,1]), round.(Int, a[:,2]), a[:,3], m, n)
3737
end

0 commit comments

Comments
 (0)