Skip to content

Commit 2c75164

Browse files
author
mohamed82008
committed
Allow pre-allocation and re-use of buffers
1 parent 756edef commit 2c75164

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

src/cg.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import Base: start, next, done
22

3-
export cg, cg!, CGIterable, PCGIterable, cg_iterator!
3+
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
44

55
mutable struct CGIterable{matT, solT, vecT, numT <: Real}
66
A::matT
@@ -90,13 +90,22 @@ end
9090

9191
# Utility functions
9292

93+
struct CGStateVariables{T}
94+
u::Vector{T}
95+
r::Vector{T}
96+
c::Vector{T}
97+
end
98+
9399
function cg_iterator!(x, A, b, Pl = Identity();
94100
tol = sqrt(eps(real(eltype(b)))),
95101
maxiter::Int = size(A, 2),
102+
statevars::CGStateVariables = CGStateVariables{eltype(x)}(zeros(x), similar(x), similar(x)),
96103
initially_zero::Bool = false
97104
)
98-
u = zeros(x)
99-
r = similar(x)
105+
u = statevars.u
106+
r = statevars.r
107+
c = statevars.c
108+
u .= zero(eltype(x))
100109
copy!(r, b)
101110

102111
# Compute r with an MV-product or not.
@@ -107,7 +116,7 @@ function cg_iterator!(x, A, b, Pl = Identity();
107116
reltol = residual * tol # Save one dot product
108117
else
109118
mv_products = 1
110-
c = A * x
119+
A_mul_B!(c, A, x)
111120
r .-= c
112121
residual = norm(r)
113122
reltol = norm(b) * tol
@@ -175,6 +184,7 @@ function cg!(x, A, b;
175184
tol = sqrt(eps(real(eltype(b)))),
176185
maxiter::Int = size(A, 2),
177186
log::Bool = false,
187+
statevars::CGStateVariables = CGStateVariables{eltype(x)}(zeros(x), similar(x), similar(x)),
178188
verbose::Bool = false,
179189
Pl = Identity(),
180190
kwargs...
@@ -184,7 +194,7 @@ function cg!(x, A, b;
184194
log && reserve!(history, :resnorm, maxiter + 1)
185195

186196
# Actually perform CG
187-
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs...)
197+
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs...)
188198
if log
189199
history.mvps = iterable.mv_products
190200
end

0 commit comments

Comments
 (0)