1
+ import Base: start, next, done
2
+
1
3
export gmres, gmres!
2
4
3
- gmres (A, b; kwargs... ) = gmres! (zeros (b), A, b; kwargs... )
5
# Storage for an Arnoldi decomposition built from the operator `A`.
#
# Type parameters:
#   T    - element type of the basis `V` and the Hessenberg matrix `H`
#   matT - type of `A`; a parameter so that the field `A` is concretely typed
type ArnoldiDecomp{T, matT}
    A::matT
    V::Matrix{T} # Orthonormal basis vectors (one per column)
    H::Matrix{T} # Hessenberg matrix
end
4
10
5
- function gmres! (x, A, b;
6
- Pl = Identity (),
7
- Pr = Identity (),
8
- tol = sqrt (eps (real (eltype (b)))),
9
- restart:: Int = min (20 , length (b)),
10
- maxiter:: Int = restart,
11
- log:: Bool = false ,
12
- kwargs...
11
# Allocate an Arnoldi decomposition for `A` with room for `order` steps.
#
# `V` gets size(A, 1) rows and order + 1 columns, `H` is (order + 1) × order;
# both are zero-initialised with element type `T`.
function ArnoldiDecomp{matT}(A::matT, order::Int, T::Type)
    basis = zeros(T, size(A, 1), order + 1)
    hessenberg = zeros(T, order + 1, order)
    ArnoldiDecomp{T, matT}(A, basis, hessenberg)
end
14
- history = ConvergenceHistory (partial = ! log, restart = restart)
15
- history[ :tol ] = tol
16
- log && reserve! (history, :resnorm , maxiter)
17
- gmres_method! (history, x, A, b; Pl = Pl, Pr = Pr, tol = tol, maxiter = maxiter, restart = restart, log = log, kwargs ... )
18
- log && shrink! (history)
19
- log ? (x, history) : x
16
+
17
# Bookkeeping used to track the residual norm without forming the residual.
#
# Type parameters:
#   numT - element type of `nullvec`
#   resT - type used for the residual norms
type Residual{numT, resT}
    current::resT         # Current, absolute, preconditioned residual
    accumulator::resT     # Used to compute the residual on the go
    nullvec::Vector{numT} # Vector in the null space of H to compute residuals
    β::resT               # The initial residual
end
21
23
22
- function gmres_method! (history:: ConvergenceHistory , x, A, b;
23
- Pl = Identity (),
24
- Pr = Identity (),
25
- tol = sqrt (eps (real (eltype (b)))),
26
- restart:: Int = min (20 , length (b)),
27
- outer:: Int = 1 ,
28
- maxiter:: Int = restart,
29
- verbose:: Bool = false ,
30
- log = false
24
# Build a `Residual` for a Krylov subspace of dimension `order`.
#
# All scalar fields start at one (in the real type belonging to `T`) and
# `nullvec` is a vector of `order + 1` ones of type `T`.
function Residual(order, T::Type)
    realT = real(T)
    unit = one(realT)
    Residual{T, realT}(unit, unit, ones(T, order + 1), unit)
end
32
- T = eltype (b)
33
30
34
- # Approximate solution
35
- arnoldi = ArnoldiDecomp (A, restart, T)
36
- residual = Residual (restart, T)
31
# State of a restarted GMRES run, advanced one Arnoldi step at a time through
# the `start` / `next` / `done` iteration protocol defined in this file.
type GMRESIterable{preclT, precrT, vecT <: AbstractVector, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
    Pl::preclT           # Left preconditioner (applied as Pl \ ...)
    Pr::precrT           # Right preconditioner (applied as Pr \ ...)
    x::vecT              # Approximate solution, updated in place
    b::vecT              # Right-hand side
    Ax::vecT             # Some room to work in.

    arnoldi::arnoldiT    # Krylov basis and Hessenberg matrix
    residual::residualT  # Cheap residual-norm bookkeeping

    mv_products::Int     # Number of products with A performed so far
    restart::Int         # Inner steps per cycle before x is updated and the basis is rebuilt
    k::Int               # Index of the current Arnoldi step within the cycle
    maxiter::Int         # Upper bound on the total number of iterations
    reltol::resT         # Absolute threshold the residual norm is compared against
    β::resT              # Residual norm at the start of the current cycle
end
37
48
38
- # Workspace vector to reduce the # allocs.
39
- reserved_vec = similar (b)
40
- β = residual. current = init! (arnoldi, x, b, Pl, reserved_vec)
41
- init_residual! (residual, β)
49
# True once the tracked residual norm has dropped to `g.reltol` or below.
function converged(g::GMRESIterable)
    g.residual.current ≤ g.reltol
end
42
50
43
- # Log the first mvp for computing the initial residual
44
- if log
45
- history. mvps += 1
46
- end
51
# Iteration protocol: the state is the number of completed iterations.
function start(::GMRESIterable)
    0
end
47
52
48
- # Stopping criterion is based on |r0| / |rk|
49
- reltol = residual. current * tol
53
# Iteration protocol: stop after `g.maxiter` iterations or upon convergence.
function done(g::GMRESIterable, iteration::Int)
    iteration ≥ g.maxiter || converged(g)
end
50
54
51
- # Total iterations (not reset after restart)
52
- total_iter = 1
55
# Perform one GMRES iteration and return `(residual_norm, iteration + 1)`.
#
# Each call does one Arnoldi step (expand, orthogonalize, update the residual
# norm). The solution `g.x` is only updated when the current cycle is full
# (`g.k == g.restart + 1`) or when this was the last iteration. A new cycle is
# started afterwards only if more iterations will follow.
function next(g::GMRESIterable, iteration::Int)
    arnoldi = g.arnoldi

    # Arnoldi step: expand the basis with one product with A
    expand!(arnoldi, g.Pl, g.Pr, g.k, g.Ax)
    g.mv_products += 1

    # Orthogonalize V[:, k + 1] w.r.t. V[:, 1 : k]
    arnoldi.H[g.k + 1, g.k] = orthogonalize_and_normalize!(
        view(arnoldi.V, :, 1 : g.k),
        view(arnoldi.V, :, g.k + 1),
        view(arnoldi.H, 1 : g.k, g.k)
    )

    # Implicitly computes the residual
    update_residual!(g.residual, arnoldi, g.k)

    g.k += 1

    # Whether the iteration stops after this step. The residual is not touched
    # between here and the restart check below, so one evaluation suffices.
    finished = done(g, iteration + 1)

    # Computation of x only at the end of the iterations
    # and at restart.
    if g.k == g.restart + 1 || finished

        # Solve the projected problem Hy = β * e1 in the least-squares sense
        rhs = solve_least_squares!(arnoldi, g.β, g.k)

        # And improve the solution x ← x + Pr \ (V * y)
        update_solution!(g.x, view(rhs, 1 : g.k - 1), arnoldi, g.Pr, g.k, g.Ax)

        g.k = 1

        # Restart only when another iteration will follow. Testing the
        # incoming `iteration` here would also restart on the very last step
        # and spend a product with A that is never used.
        if !finished

            # Set the first basis vector
            g.β = init!(arnoldi, g.x, g.b, g.Pl, g.Ax)

            # And initialize the residual
            init_residual!(g.residual, g.β)

            g.mv_products += 1
        end
    end

    g.residual.current, iteration + 1
end
115
100
116
- type ArnoldiDecomp{T}
117
- A
118
- V:: Matrix{T} # Orthonormal basis vectors
119
- H:: Matrix{T} # Hessenberg matrix
120
- end
101
# Non-mutating entry point: start from a zero initial guess shaped like `b`
# and forward all keyword arguments to `gmres_iterable!`.
function gmres_iterable(A, b; kwargs...)
    x0 = zeros(b)
    gmres_iterable!(x0, A, b; initially_zero = true, kwargs...)
end
121
102
122
- ArnoldiDecomp (A, order:: Int , T:: Type ) = ArnoldiDecomp {T} (
123
- A,
124
- zeros (T, size (A, 1 ), order + 1 ),
125
- zeros (T, order + 1 , order)
103
# Build a `GMRESIterable` that solves A * x = b, updating `x` in place.
#
# Keywords:
#   Pl, Pr         - left and right preconditioners
#   tol            - relative tolerance; the stopping threshold is
#                    tol * (initial residual norm)
#   restart        - number of inner steps per cycle
#   maxiter        - maximum total number of iterations
#   initially_zero - set to true when `x` is known to be zero, so that the
#                    initial product A * x can be skipped
function gmres_iterable!(x, A, b;
    Pl = Identity(),
    Pr = Identity(),
    tol = sqrt(eps(real(eltype(b)))),
    restart::Int = min(20, length(b)),
    maxiter::Int = restart,
    initially_zero = false
)
    T = eltype(b)

    # Approximate solution
    arnoldi = ArnoldiDecomp(A, restart, T)
    residual = Residual(restart, T)

    # init! multiplies by A only when x is not known to be zero, so the
    # initial residual costs one product in that case and none otherwise.
    mv_products = initially_zero ? 0 : 1

    # Workspace vector to reduce the # allocs.
    Ax = similar(b)
    residual.current = init!(arnoldi, x, b, Pl, Ax, initially_zero = initially_zero)
    init_residual!(residual, residual.current)

    # Stopping criterion is relative to the initial residual norm
    reltol = tol * residual.current

    GMRESIterable(Pl, Pr, x, b, Ax,
        arnoldi, residual,
        mv_products, restart, 1, maxiter, reltol, residual.current
    )
end
134
130
135
- Residual (order, T:: Type ) = Residual {T, real(T)} (
136
- one (real (T)),
137
- one (real (T)),
138
- ones (T, order + 1 ),
139
- one (real (T))
131
# Non-mutating GMRES: solve from a zero initial guess shaped like `b`,
# forwarding all keyword arguments to `gmres!`.
function gmres(A, b; kwargs...)
    x0 = zeros(b)
    gmres!(x0, A, b; initially_zero = true, kwargs...)
end
132
+
133
# Solve A * x = b with restarted GMRES, overwriting `x`.
#
# Drives a `GMRESIterable` to completion. Returns `x`, or `(x, history)` when
# `log` is true. With `verbose` set, one line per iteration is printed showing
# the restart cycle, the inner iteration and the residual norm.
function gmres!(x, A, b;
    Pl = Identity(),
    Pr = Identity(),
    tol = sqrt(eps(real(eltype(b)))),
    restart::Int = min(20, length(b)),
    maxiter::Int = restart,
    log::Bool = false,
    initially_zero = false,
    verbose::Bool = false
)
    history = ConvergenceHistory(partial = !log, restart = restart)
    history[:tol] = tol
    if log
        reserve!(history, :resnorm, maxiter)
    end

    iterable = gmres_iterable!(x, A, b;
        Pl = Pl, Pr = Pr, tol = tol, maxiter = maxiter,
        restart = restart, initially_zero = initially_zero
    )

    if verbose
        @printf(" === gmres ===\n %4s\t %4s\t %7s\n ", " rest", " iter", " resnorm")
    end

    for (iteration, residual) = enumerate(iterable)
        if log
            nextiter!(history)
            # The iterable keeps the running count of products with A
            history.mvps = iterable.mv_products
            push!(history, :resnorm, residual)
        end

        if verbose
            cycle = 1 + div(iteration - 1, restart)
            inner = 1 + mod(iteration - 1, restart)
            @printf(" %3d\t %3d\t %1.2e\n ", cycle, inner, residual)
        end
    end

    if verbose
        println()
    end

    setconv(history, converged(iterable))

    if log
        shrink!(history)
        return (x, history)
    end

    x
end
141
167
142
168
function update_residual! (r:: Residual , arnoldi:: ArnoldiDecomp , k:: Int )
143
169
# Cheaply computes the current residual
@@ -146,15 +172,20 @@ function update_residual!(r::Residual, arnoldi::ArnoldiDecomp, k::Int)
146
172
r. current = r. β / √ r. accumulator
147
173
end
148
174
149
- function init! {T} (arnoldi:: ArnoldiDecomp{T} , x, b, Pl, reserved_vec )
175
+ function init! {T} (arnoldi:: ArnoldiDecomp{T} , x, b, Pl, Ax; initially_zero :: Bool = false )
150
176
# Initialize the Krylov subspace with the initial residual vector
151
177
# This basically does V[1] = Pl \ (b - A * x) and then normalize
152
178
153
179
first_col = view (arnoldi. V, :, 1 )
154
180
155
181
copy! (first_col, b)
156
- A_mul_B! (reserved_vec, arnoldi. A, x)
157
- @blas! first_col -= one (T) * reserved_vec
182
+
183
+ # Potentially save one MV product
184
+ if ! initially_zero
185
+ A_mul_B! (Ax, arnoldi. A, x)
186
+ @blas! first_col -= one (T) * Ax
187
+ end
188
+
158
189
A_ldiv_B! (Pl, first_col)
159
190
160
191
# Normalize
@@ -179,33 +210,37 @@ function solve_least_squares!{T}(arnoldi::ArnoldiDecomp{T}, β, k::Int)
179
210
rhs
180
211
end
181
212
182
# Update x ← x + V * y when there is no right preconditioner.
#
# `Ax` is unused here; it is accepted so that this method has the same
# signature as the generic `update_solution!`.
function update_solution!{T}(x, y, arnoldi::ArnoldiDecomp{T}, Pr::Identity, k::Int, Ax)
    basis = view(arnoldi.V, :, 1 : k - 1)

    # TODO: find the SugarBLAS alternative
    BLAS.gemv!('N', one(T), basis, y, one(T), x)
end
188
219
189
# Update x ← x + Pr \ (V * y), using `Ax` as scratch space so that no
# temporary vector is allocated. `Ax` is overwritten.
function update_solution!{T}(x, y, arnoldi::ArnoldiDecomp{T}, Pr, k::Int, Ax)
    basis = view(arnoldi.V, :, 1 : k - 1)

    # Ax ← V * y
    A_mul_B!(Ax, basis, y)

    # Ax ← Pr \ Ax
    A_ldiv_B!(Pr, Ax)

    @blas! x += one(T) * Ax
end
194
226
195
# Expand the basis without preconditioning: V[:, k + 1] ← A * V[:, k].
# Writes straight into the next column, so `Ax` is not needed here.
function expand!(arnoldi::ArnoldiDecomp, Pl::Identity, Pr::Identity, k::Int, Ax)
    basis = arnoldi.V
    A_mul_B!(view(basis, :, k + 1), arnoldi.A, view(basis, :, k))
end
199
231
200
# Expand the basis with a left preconditioner only:
# V[:, k + 1] ← Pl \ (A * V[:, k]), computed in place in the next column.
function expand!(arnoldi::ArnoldiDecomp, Pl, Pr::Identity, k::Int, Ax)
    current = view(arnoldi.V, :, k)
    target = view(arnoldi.V, :, k + 1)

    A_mul_B!(target, arnoldi.A, current)
    A_ldiv_B!(Pl, target)
end
205
238
206
# Expand the basis with both preconditioners:
# V[:, k + 1] ← Pl \ (A * (Pr \ V[:, k])).
# The product with A goes through the work vector `Ax` (overwritten) so that
# no temporary is allocated.
function expand!(arnoldi::ArnoldiDecomp, Pl, Pr, k::Int, Ax)
    current = view(arnoldi.V, :, k)
    target = view(arnoldi.V, :, k + 1)

    # target ← Pr \ V[:, k]
    A_ldiv_B!(target, Pr, current)

    # target ← A * target, staged through Ax
    A_mul_B!(Ax, arnoldi.A, target)
    copy!(target, Ax)

    # target ← Pl \ target
    A_ldiv_B!(Pl, target)
end
0 commit comments