Skip to content

Commit 586147a

Browse files
lostellamohamed82008
authored andcommitted
updated iteration protocol
1 parent b0f4377 commit 586147a

File tree

8 files changed

+52
-23
lines changed

8 files changed

+52
-23
lines changed

src/bicgstabl.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export bicgstabl, bicgstabl!, bicgstabl_iterator, bicgstabl_iterator!, BiCGStabIterable
22
using Printf
3-
import Base: start, next, done
3+
import Base: iterate
44

55
mutable struct BiCGStabIterable{precT, matT, solT, vecT <: AbstractVector, smallMatT <: AbstractMatrix, realT <: Real, scalarT <: Number}
66
A::matT
@@ -76,7 +76,9 @@ end
7676
@inline start(::BiCGStabIterable) = 0
7777
@inline done(it::BiCGStabIterable, iteration::Int) = it.mv_products it.max_mv_products || converged(it)
7878

79-
function next(it::BiCGStabIterable, iteration::Int)
79+
function iterate(it::BiCGStabIterable, iteration::Int=start(it))
80+
if done(it, iteration) return nothing end
81+
8082
T = eltype(it.x)
8183
L = 2 : it.l + 1
8284

src/cg.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base: start, next, done
1+
import Base: iterate
22
using Printf
33
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
44

@@ -40,7 +40,9 @@ end
4040
# Ordinary CG #
4141
###############
4242

43-
function next(it::CGIterable, iteration::Int)
43+
function iterate(it::CGIterable, iteration::Int=start(it))
44+
if done(it, iteration) return nothing end
45+
4446
# u := r + βu (almost an axpy)
4547
β = it.residual^2 / it.prev_residual^2
4648
it.u .= it.r .+ β .* it.u
@@ -64,7 +66,12 @@ end
6466
# Preconditioned CG #
6567
#####################
6668

67-
function next(it::PCGIterable, iteration::Int)
69+
function iterate(it::PCGIterable, iteration::Int=start(it))
70+
# Check for termination first
71+
if done(it, iteration)
72+
return nothing
73+
end
74+
6875
ldiv!(it.c, it.Pl, it.r)
6976

7077
ρ_prev = it.ρ

src/chebyshev.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base: next, start, done
1+
import Base: iterate
22

33
export chebyshev, chebyshev!
44

@@ -26,7 +26,9 @@ converged(c::ChebyshevIterable) = c.resnorm ≤ c.reltol
2626
start(::ChebyshevIterable) = 0
2727
done(c::ChebyshevIterable, iteration::Int) = iteration c.maxiter || converged(c)
2828

29-
function next(cheb::ChebyshevIterable, iteration::Int)
29+
function iterate(cheb::ChebyshevIterable, iteration::Int=start(cheb))
30+
if done(cheb, iteration) return nothing end
31+
3032
T = eltype(cheb.x)
3133

3234
ldiv!(cheb.c, cheb.Pl, cheb.r)

src/gmres.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base: start, next, done
1+
import Base: iterate
22
using Printf
33
export gmres, gmres!
44

@@ -52,7 +52,8 @@ start(::GMRESIterable) = 0
5252

5353
done(g::GMRESIterable, iteration::Int) = iteration g.maxiter || converged(g)
5454

55-
function next(g::GMRESIterable, iteration::Int)
55+
function iterate(g::GMRESIterable, iteration::Int=start(g))
56+
if done(g, iteration) return nothing end
5657

5758
# Arnoldi step: expand
5859
expand!(g.arnoldi, g.Pl, g.Pr, g.k, g.Ax)

src/minres.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export minres_iterable, minres, minres!
22
using Printf
33
import LinearAlgebra: BLAS.axpy!, givensAlgorithm
4-
import Base: start, next, done
4+
import Base: iterate
55

66
mutable struct MINRESIterable{matT, solT, vecT <: DenseVector, smallVecT <: DenseVector, rotT <: Number, realT <: Real}
77
A::matT
@@ -94,7 +94,9 @@ start(::MINRESIterable) = 1
9494

9595
done(m::MINRESIterable, iteration::Int) = iteration > m.maxiter || converged(m)
9696

97-
function next(m::MINRESIterable, iteration::Int)
97+
function iterate(m::MINRESIterable, iteration::Int=start(m))
98+
if done(m, iteration) return nothing end
99+
98100
# v_next = A * v_curr - H[2] * v_prev
99101
mul!(m.v_next, m.A, m.v_curr)
100102

src/simple.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base: start, next, done
1+
import Base: iterate
22

33
#Simple methods
44
export powm, powm!, invpowm, invpowm!
@@ -25,7 +25,8 @@ end
2525

2626
@inline done(p::PowerMethodIterable, iteration::Int) = iteration > p.maxiter || converged(p)
2727

28-
function next(p::PowerMethodIterable, iteration::Int)
28+
function iterate(p::PowerMethodIterable, iteration::Int=start(p))
29+
if done(p, iteration) return nothing end
2930

3031
mul!(p.Ax, p.A, p.x)
3132

src/stationary.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export jacobi, jacobi!, gauss_seidel, gauss_seidel!, sor, sor!, ssor, ssor!
22

33
import LinearAlgebra.SingularException
4-
import Base: start, next, done, getindex
4+
import Base: getindex, iterate
55

66
function check_diag(A::AbstractMatrix)
77
for i = 1 : size(A, 1)
@@ -45,7 +45,9 @@ end
4545

4646
start(::DenseJacobiIterable) = 1
4747
done(it::DenseJacobiIterable, iteration::Int) = iteration > it.maxiter
48-
function next(j::DenseJacobiIterable, iteration::Int)
48+
function iterate(j::DenseJacobiIterable, iteration::Int=start(j))
49+
if done(j, iteration) return nothing end
50+
4951
n = size(j.A, 1)
5052

5153
copyto!(j.next, j.b)
@@ -103,7 +105,9 @@ end
103105
start(::DenseGaussSeidelIterable) = 1
104106
done(it::DenseGaussSeidelIterable, iteration::Int) = iteration > it.maxiter
105107

106-
function next(s::DenseGaussSeidelIterable, iteration::Int)
108+
function iterate(s::DenseGaussSeidelIterable, iteration::Int=start(s))
109+
if done(s, iteration) return nothing end
110+
107111
n = size(s.A, 1)
108112

109113
for col = 1 : n
@@ -160,7 +164,9 @@ end
160164

161165
start(::DenseSORIterable) = 1
162166
done(it::DenseSORIterable, iteration::Int) = iteration > it.maxiter
163-
function next(s::DenseSORIterable, iteration::Int)
167+
function iterate(s::DenseSORIterable, iteration::Int=start(s))
168+
if done(s, iteration) return nothing end
169+
164170
n = size(s.A, 1)
165171

166172
for col = 1 : n
@@ -218,7 +224,9 @@ end
218224

219225
start(::DenseSSORIterable) = 1
220226
done(it::DenseSSORIterable, iteration::Int) = iteration > it.maxiter
221-
function next(s::DenseSSORIterable, iteration::Int)
227+
function iterate(s::DenseSSORIterable, iteration::Int=start(s))
228+
if done(s, iteration) return nothing end
229+
222230
n = size(s.A, 1)
223231

224232
for col = 1 : n

src/stationary_sparse.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import LinearAlgebra: mul!, ldiv!
2-
import Base: start, next, done, getindex
2+
import Base: getindex, iterate
33

44
using SparseArrays
55

@@ -222,7 +222,8 @@ end
222222

223223
start(::JacobiIterable) = 1
224224
done(j::JacobiIterable, iteration::Int) = iteration > j.maxiter
225-
function next(j::JacobiIterable{T}, iteration::Int) where {T}
225+
function iterate(j::JacobiIterable{T}, iteration::Int=start(j)) where {T}
226+
if done(j, iteration) return nothing end
226227
# tmp = D \ (b - (A - D) * x)
227228
copyto!(j.next, j.b)
228229
mul!(-one(T), j.O, j.x, one(T), j.next)
@@ -273,7 +274,8 @@ end
273274

274275
start(::GaussSeidelIterable) = 1
275276
done(g::GaussSeidelIterable, iteration::Int) = iteration > g.maxiter
276-
function next(g::GaussSeidelIterable, iteration::Int)
277+
function iterate(g::GaussSeidelIterable, iteration::Int=start(g))
278+
if done(g, iteration) return nothing end
277279
# x ← L \ (-U * x + b)
278280
T = eltype(g.x)
279281
gauss_seidel_multiply!(-one(T), g.U, g.x, one(T), g.b, g.x)
@@ -316,7 +318,9 @@ end
316318

317319
start(::SORIterable) = 1
318320
done(s::SORIterable, iteration::Int) = iteration > s.maxiter
319-
function next(s::SORIterable{T}, iteration::Int) where {T}
321+
function iterate(s::SORIterable{T}, iteration::Int=start(s)) where {T}
322+
if done(s, iteration) return nothing end
323+
320324
# next = b - U * x
321325
gauss_seidel_multiply!(-one(T), s.U, s.x, one(T), s.b, s.next)
322326

@@ -384,7 +388,9 @@ end
384388
start(s::SSORIterable) = 1
385389
done(s::SSORIterable, iteration::Int) = iteration > s.maxiter
386390

387-
function next(s::SSORIterable{T}, iteration::Int) where {T}
391+
function iterate(s::SSORIterable{T}, iteration::Int=start(s)) where {T}
392+
if done(s, iteration) return nothing end
393+
388394
# tmp = b - U * x
389395
gauss_seidel_multiply!(-one(T), s.sU, s.x, one(T), s.b, s.tmp)
390396

0 commit comments

Comments
 (0)