1
1
import Base: start, next, done
2
2
3
- export cg, cg!, CGIterable, PCGIterable, cg_iterator!
3
+ export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
4
4
5
5
mutable struct CGIterable{matT, solT, vecT, numT <: Real }
6
6
A:: matT
90
90
91
91
# Utility functions
92
92
93
+ """
94
+ Intermediate CG state variables to be used inside cg and cg!. `u`, `r` and `c` should be of the same type as the solution of `cg` or `cg!`.
95
+ ```
96
+ struct CGStateVariables{T,Tx<:AbstractArray{T}}
97
+ u::Tx
98
+ r::Tx
99
+ c::Tx
100
+ end
101
+ ```
102
+ """
103
+ struct CGStateVariables{T,Tx<: AbstractArray{T} }
104
+ u:: Tx
105
+ r:: Tx
106
+ c:: Tx
107
+ end
108
+
93
109
function cg_iterator! (x, A, b, Pl = Identity ();
94
110
tol = sqrt (eps (real (eltype (b)))),
95
111
maxiter:: Int = size (A, 2 ),
112
+ statevars:: CGStateVariables = CGStateVariables {eltype(x),typeof(x)} (zeros (x), similar (x), similar (x)),
96
113
initially_zero:: Bool = false
97
114
)
98
- u = zeros (x)
99
- r = similar (x)
115
+ u = statevars. u
116
+ r = statevars. r
117
+ c = statevars. c
118
+ u .= zero (eltype (x))
100
119
copy! (r, b)
101
120
102
121
# Compute r with an MV-product or not.
@@ -107,7 +126,7 @@ function cg_iterator!(x, A, b, Pl = Identity();
107
126
reltol = residual * tol # Save one dot product
108
127
else
109
128
mv_products = 1
110
- c = A * x
129
+ A_mul_B! (c, A, x)
111
130
r .- = c
112
131
residual = norm (r)
113
132
reltol = norm (b) * tol
@@ -145,15 +164,16 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
145
164
146
165
## Keywords
147
166
167
+ - `statevars::CGStateVariables`: Has 3 arrays similar to `x` to hold intermediate results;
148
168
- `initially_zero::Bool`: If `true` assumes that `iszero(x)` so that one
149
169
matrix-vector product can be saved when computing the initial
150
170
residual vector;
151
171
- `Pl = Identity()`: left preconditioner of the method. Should be symmetric,
152
- positive-definite like `A`.
172
+ positive-definite like `A`;
153
173
- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`;
154
174
- `maxiter::Int = size(A,2)`: maximum number of iterations;
155
175
- `verbose::Bool = false`: print method information;
156
- - `log::Bool = false`: keep track of the residual norm in each iteration;
176
+ - `log::Bool = false`: keep track of the residual norm in each iteration.
157
177
158
178
# Output
159
179
@@ -175,6 +195,7 @@ function cg!(x, A, b;
175
195
tol = sqrt (eps (real (eltype (b)))),
176
196
maxiter:: Int = size (A, 2 ),
177
197
log:: Bool = false ,
198
+ statevars:: CGStateVariables = CGStateVariables {eltype(x), typeof(x)} (zeros (x), similar (x), similar (x)),
178
199
verbose:: Bool = false ,
179
200
Pl = Identity (),
180
201
kwargs...
@@ -184,7 +205,7 @@ function cg!(x, A, b;
184
205
log && reserve! (history, :resnorm , maxiter + 1 )
185
206
186
207
# Actually perform CG
187
- iterable = cg_iterator! (x, A, b, Pl; tol = tol, maxiter = maxiter, kwargs... )
208
+ iterable = cg_iterator! (x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs... )
188
209
if log
189
210
history. mvps = iterable. mv_products
190
211
end
0 commit comments