Skip to content

Commit c1c6ded

Browse files
authored
Merge pull request #193 from mohamed82008/master
Allow pre-allocation and re-use of buffers in CG
2 parents 756edef + 9a5d488 commit c1c6ded

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

src/cg.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import Base: start, next, done
22

3-
export cg, cg!, CGIterable, PCGIterable, cg_iterator!
3+
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
44

55
mutable struct CGIterable{matT, solT, vecT, numT <: Real}
66
A::matT
@@ -90,13 +90,32 @@ end
9090

9191
# Utility functions
9292

93+
"""
94+
Intermediate CG state variables to be used inside cg and cg!. `u`, `r` and `c` should be of the same type as the solution of `cg` or `cg!`.
95+
```
96+
struct CGStateVariables{T,Tx<:AbstractArray{T}}
97+
u::Tx
98+
r::Tx
99+
c::Tx
100+
end
101+
```
102+
"""
103+
struct CGStateVariables{T,Tx<:AbstractArray{T}}
104+
u::Tx
105+
r::Tx
106+
c::Tx
107+
end
108+
93109
function cg_iterator!(x, A, b, Pl = Identity();
94110
tol = sqrt(eps(real(eltype(b)))),
95111
maxiter::Int = size(A, 2),
112+
statevars::CGStateVariables = CGStateVariables{eltype(x),typeof(x)}(zeros(x), similar(x), similar(x)),
96113
initially_zero::Bool = false
97114
)
98-
u = zeros(x)
99-
r = similar(x)
115+
u = statevars.u
116+
r = statevars.r
117+
c = statevars.c
118+
u .= zero(eltype(x))
100119
copy!(r, b)
101120

102121
# Compute r with an MV-product or not.
@@ -107,7 +126,7 @@ function cg_iterator!(x, A, b, Pl = Identity();
107126
reltol = residual * tol # Save one dot product
108127
else
109128
mv_products = 1
110-
c = A * x
129+
A_mul_B!(c, A, x)
111130
r .-= c
112131
residual = norm(r)
113132
reltol = norm(b) * tol
@@ -145,15 +164,16 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
145164
146165
## Keywords
147166
167+
- `statevars::CGStateVariables`: Has 3 arrays similar to `x` to hold intermediate results;
148168
- `initially_zero::Bool`: If `true` assumes that `iszero(x)` so that one
149169
matrix-vector product can be saved when computing the initial
150170
residual vector;
151171
- `Pl = Identity()`: left preconditioner of the method. Should be symmetric,
152-
positive-definite like `A`.
172+
positive-definite like `A`;
153173
- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`;
154174
- `maxiter::Int = size(A,2)`: maximum number of iterations;
155175
- `verbose::Bool = false`: print method information;
156-
- `log::Bool = false`: keep track of the residual norm in each iteration;
176+
- `log::Bool = false`: keep track of the residual norm in each iteration.
157177
158178
# Output
159179
@@ -175,6 +195,7 @@ function cg!(x, A, b;
175195
tol = sqrt(eps(real(eltype(b)))),
176196
maxiter::Int = size(A, 2),
177197
log::Bool = false,
198+
statevars::CGStateVariables = CGStateVariables{eltype(x), typeof(x)}(zeros(x), similar(x), similar(x)),
178199
verbose::Bool = false,
179200
Pl = Identity(),
180201
kwargs...
@@ -184,7 +205,7 @@ function cg!(x, A, b;
184205
log && reserve!(history, :resnorm, maxiter + 1)
185206

186207
# Actually perform CG
187-
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs...)
208+
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs...)
188209
if log
189210
history.mvps = iterable.mv_products
190211
end

0 commit comments

Comments
 (0)