Skip to content

Commit 0aa0760

Browse files
haampieandreasnoack
authored andcommitted
Fix element types (#163)
* Remove unused methods * Clean up common.jl * Ensure r has the type of x, not b. Remove b from the iterable since it is not used anyway * Make residual type in BiCGStab(l) equal to solution type and remove rhs from iterable * Chebyshev: residual type should equal x's type and remove b from iterable * Similar story for GMRES * Use x element type everywhere * Use different types for solution, temporary and rhs in iterables of stationary methods * zerox over zeros
1 parent 1d036a3 commit 0aa0760

File tree

8 files changed

+82
-118
lines changed

8 files changed

+82
-118
lines changed

src/bicgstabl.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ export bicgstabl, bicgstabl!, bicgstabl_iterator, bicgstabl_iterator!, BiCGStabI
22

33
import Base: start, next, done
44

5-
mutable struct BiCGStabIterable{precT, matT, vecT <: AbstractVector, smallMatT <: AbstractMatrix, realT <: Real, scalarT <: Number}
5+
mutable struct BiCGStabIterable{precT, matT, solT, vecT <: AbstractVector, smallMatT <: AbstractMatrix, realT <: Real, scalarT <: Number}
66
A::matT
7-
b::vecT
87
l::Int
98

10-
x::vecT
9+
x::solT
1110
r_shadow::vecT
1211
rs::smallMatT
1312
us::smallMatT
@@ -33,7 +32,7 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
3332
initial_zero = false,
3433
tol = sqrt(eps(real(eltype(b))))
3534
)
36-
T = eltype(b)
35+
T = eltype(x)
3736
n = size(A, 1)
3837
mv_products = 0
3938

@@ -69,7 +68,7 @@ function bicgstabl_iterator!(x, A, b, l::Int = 2;
6968
# Stopping condition based on relative tolerance.
7069
reltol = nrm * tol
7170

72-
BiCGStabIterable(A, b, l, x, r_shadow, rs, us,
71+
BiCGStabIterable(A, l, x, r_shadow, rs, us,
7372
max_mv_products, mv_products, reltol, nrm,
7473
Pl,
7574
γ, ω, σ, M
@@ -81,7 +80,7 @@ end
8180
@inline done(it::BiCGStabIterable, iteration::Int) = it.mv_products it.max_mv_products || converged(it)
8281

8382
function next(it::BiCGStabIterable, iteration::Int)
84-
T = eltype(it.b)
83+
T = eltype(it.x)
8584
L = 2 : it.l + 1
8685

8786
it.σ = -it.ω * it.σ

src/cg.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ import Base: start, next, done
22

33
export cg, cg!, CGIterable, PCGIterable, cg_iterator, cg_iterator!
44

5-
mutable struct CGIterable{matT, vecT <: AbstractVector, numT <: Real}
5+
mutable struct CGIterable{matT, solT, vecT, numT <: Real}
66
A::matT
7-
x::vecT
8-
b::vecT
7+
x::solT
98
r::vecT
109
c::vecT
1110
u::vecT
@@ -16,11 +15,10 @@ mutable struct CGIterable{matT, vecT <: AbstractVector, numT <: Real}
1615
mv_products::Int
1716
end
1817

19-
mutable struct PCGIterable{precT, matT, vecT <: AbstractVector, numT <: Real, paramT <: Number}
18+
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number}
2019
Pl::precT
2120
A::matT
22-
x::vecT
23-
b::vecT
21+
x::solT
2422
r::vecT
2523
c::vecT
2624
u::vecT
@@ -46,7 +44,7 @@ function next(it::CGIterable, iteration::Int)
4644
# u := r + βu (almost an axpy)
4745
β = it.residual^2 / it.prev_residual^2
4846
@blas! it.u *= β
49-
@blas! it.u += one(eltype(it.b)) * it.r
47+
@blas! it.u += one(eltype(it.u)) * it.r
5048

5149
# c = A * u
5250
A_mul_B!(it.c, it.A, it.u)
@@ -76,7 +74,7 @@ function next(it::PCGIterable, iteration::Int)
7674
# u := c + βu (almost an axpy)
7775
β = it.ρ / ρ_prev
7876
@blas! it.u *= β
79-
@blas! it.u += one(eltype(it.b)) * it.c
77+
@blas! it.u += one(eltype(it.u)) * it.c
8078

8179
# c = A * u
8280
A_mul_B!(it.c, it.A, it.u)
@@ -102,7 +100,8 @@ function cg_iterator!(x, A, b, Pl = Identity();
102100
initially_zero::Bool = false
103101
)
104102
u = zeros(x)
105-
r = copy(b)
103+
r = similar(x)
104+
copy!(r, b)
106105

107106
# Compute r with an MV-product or not.
108107
if initially_zero
@@ -120,16 +119,13 @@ function cg_iterator!(x, A, b, Pl = Identity();
120119

121120
# Return the iterable
122121
if isa(Pl, Identity)
123-
return CGIterable(A, x, b,
124-
r, c, u,
122+
return CGIterable(A, x, r, c, u,
125123
reltol, residual, one(residual),
126124
maxiter, mv_products
127125
)
128126
else
129-
ρ = one(eltype(r))
130-
return PCGIterable(Pl, A, x, b,
131-
r, c, u,
132-
reltol, residual, ρ,
127+
return PCGIterable(Pl, A, x, r, c, u,
128+
reltol, residual, one(eltype(x)),
133129
maxiter, mv_products
134130
)
135131
end

src/chebyshev.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ import Base: next, start, done
22

33
export chebyshev, chebyshev!
44

5-
mutable struct ChebyshevIterable{precT, matT, vecT, realT <: Real}
5+
mutable struct ChebyshevIterable{precT, matT, solT, vecT, realT <: Real}
66
Pl::precT
77
A::matT
8-
b::vecT
98

10-
x::vecT
9+
x::solT
1110
r::vecT
1211
u::vecT
1312
c::vecT
@@ -28,7 +27,7 @@ start(::ChebyshevIterable) = 0
2827
done(c::ChebyshevIterable, iteration::Int) = iteration c.maxiter || converged(c)
2928

3029
function next(cheb::ChebyshevIterable, iteration::Int)
31-
T = eltype(cheb.u)
30+
T = eltype(cheb.x)
3231

3332
solve!(cheb.c, cheb.Pl, cheb.r)
3433

@@ -64,8 +63,9 @@ function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real;
6463
λ_avg = (λmax + λmin) / 2
6564
λ_diff = (λmax - λmin) / 2
6665

67-
T = eltype(b)
68-
r = copy(b)
66+
T = eltype(x)
67+
r = similar(x)
68+
copy!(r, b)
6969
u = zeros(x)
7070
c = similar(x)
7171

@@ -82,8 +82,7 @@ function chebyshev_iterable!(x, A, b, λmin::Real, λmax::Real;
8282
mv_products = 1
8383
end
8484

85-
ChebyshevIterable(Pl, A, b,
86-
x, r, u, c,
85+
ChebyshevIterable(Pl, A, x, r, u, c,
8786
zero(real(T)),
8887
λ_avg, λ_diff,
8988
resnorm, reltol, maxiter, mv_products

src/common.jl

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,23 @@ Determine type of the division of an element of `b` against an element of `A`:
1212
"""
1313
Adivtype(A, b) = typeof(one(eltype(b))/one(eltype(A)))
1414

15-
"""
16-
Amultype(A, x)
17-
Determine type of the multiplication of an element of `b` with an element of `A`:
18-
`typeof(one(eltype(A))*one(eltype(x)))`
19-
"""
20-
Amultype(A, x) = typeof(one(eltype(A))*one(eltype(x)))
21-
22-
"""
23-
randx(A, b)
24-
Build a random unitary vector `Vector{T}`, where `T` is `Adivtype(A,b)`.
25-
"""
26-
function randx(A, b)
27-
T = Adivtype(A, b)
28-
x = initrand!(Array(T, size(A, 2)))
29-
end
30-
3115
"""
3216
zerox(A, b)
3317
Build a zeros vector `Vector{T}`, where `T` is `Adivtype(A,b)`.
3418
"""
35-
function zerox(A, b)
36-
T = Adivtype(A, b)
37-
x = zeros(T, size(A, 2))
38-
end
19+
zerox(A, b) = zeros(Adivtype(A, b), size(A, 2))
3920

4021
#### Numerics
4122
"""
4223
solve(A,b)
43-
Solve `A\b` with a direct solver. When `A` is a function `A(b)` is dispatched instead.
24+
Solve `A\\b` with a direct solver. When `A` is a function `A(b)` is dispatched instead.
4425
"""
4526
solve(A::Function,b) = A(b)
46-
4727
solve(A,b) = A\b
48-
4928
solve!(out::AbstractArray{T},A::Int,b::AbstractArray{T}) where {T} = scale!(out,b, 1/A)
50-
5129
solve!(out::AbstractArray{T},A,b::AbstractArray{T}) where {T} = A_ldiv_B!(out,A,b)
5230
solve!(out::AbstractArray{T},A::Function,b::AbstractArray{T}) where {T} = copy!(out,A(b))
5331

54-
"""
55-
initrand!(v)
56-
Overwrite `v` with a random unitary vector of the same length.
57-
"""
58-
function initrand!(v::Vector)
59-
_randn!(v)
60-
nv = norm(v)
61-
for i = 1:length(v)
62-
v[i] /= nv
63-
end
64-
v
65-
end
66-
_randn!(v::Array{Float64}) = randn!(v)
67-
_randn!(v) = copy!(v, randn(length(v)))
68-
6932
# Identity preconditioner
7033
struct Identity end
7134

src/gmres.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ Residual(order, T::Type) = Residual{T, real(T)}(
2828
one(real(T))
2929
)
3030

31-
mutable struct GMRESIterable{preclT, precrT, vecT <: AbstractVector, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
31+
mutable struct GMRESIterable{preclT, precrT, solT, rhsT, vecT, arnoldiT <: ArnoldiDecomp, residualT <: Residual, resT <: Real}
3232
Pl::preclT
3333
Pr::precrT
34-
x::vecT
35-
b::vecT
34+
x::solT
35+
b::rhsT
3636
Ax::vecT # Some room to work in.
3737

3838
arnoldi::arnoldiT
@@ -98,25 +98,25 @@ function next(g::GMRESIterable, iteration::Int)
9898
g.residual.current, iteration + 1
9999
end
100100

101-
gmres_iterable(A, b; kwargs...) = gmres_iterable!(zeros(b), A, b; initially_zero = true, kwargs...)
101+
gmres_iterable(A, b; kwargs...) = gmres_iterable!(zerox(A, b), A, b; initially_zero = true, kwargs...)
102102

103103
function gmres_iterable!(x, A, b;
104104
Pl = Identity(),
105105
Pr = Identity(),
106106
tol = sqrt(eps(real(eltype(b)))),
107107
restart::Int = min(20, length(b)),
108108
maxiter::Int = restart,
109-
initially_zero = false
109+
initially_zero::Bool = false
110110
)
111-
T = eltype(b)
111+
T = eltype(x)
112112

113113
# Approximate solution
114114
arnoldi = ArnoldiDecomp(A, restart, T)
115115
residual = Residual(restart, T)
116-
mv_products = initially_zero == true ? 1 : 0
116+
mv_products = initially_zero ? 1 : 0
117117

118118
# Workspace vector to reduce the # allocs.
119-
Ax = similar(b)
119+
Ax = similar(x)
120120
residual.current = init!(arnoldi, x, b, Pl, Ax, initially_zero = initially_zero)
121121
init_residual!(residual, residual.current)
122122

@@ -133,7 +133,7 @@ end
133133
134134
Same as [`gmres!`](@ref), but allocates a solution vector `x` initialized with zeros.
135135
"""
136-
gmres(A, b; kwargs...) = gmres!(zeros(b), A, b; initially_zero = true, kwargs...)
136+
gmres(A, b; kwargs...) = gmres!(zerox(A, b), A, b; initially_zero = true, kwargs...)
137137

138138
"""
139139
gmres!(x, A, b; kwargs...) -> x, [history]

src/minres.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ export minres_iterable, minres, minres!
33
import Base.LinAlg: BLAS.axpy!, givensAlgorithm
44
import Base: start, next, done
55

6-
mutable struct MINRESIterable{matT, vecT <: DenseVector, smallVecT <: DenseVector, rotT <: Number, realT <: Real}
6+
mutable struct MINRESIterable{matT, solT, vecT <: DenseVector, smallVecT <: DenseVector, rotT <: Number, realT <: Real}
77
A::matT
88
skew_hermitian::Bool
9-
x::vecT
9+
x::solT
1010

1111
# Krylov basis vectors
1212
v_prev::vecT
@@ -44,15 +44,16 @@ function minres_iterable!(x, A, b;
4444
tol = sqrt(eps(real(eltype(b)))),
4545
maxiter = size(A, 1)
4646
)
47-
T = eltype(b)
47+
T = eltype(x)
4848
HessenbergT = skew_hermitian ? T : real(T)
4949

50-
v_prev = similar(b)
51-
v_curr = copy(b)
52-
v_next = similar(b)
53-
w_prev = similar(b)
54-
w_curr = similar(b)
55-
w_next = similar(b)
50+
v_prev = similar(x)
51+
v_curr = similar(x)
52+
copy!(v_curr, b)
53+
v_next = similar(x)
54+
w_prev = similar(x)
55+
w_curr = similar(x)
56+
w_next = similar(x)
5657

5758
mv_products = 0
5859

src/stationary.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ function jacobi!(x, A::AbstractMatrix, b; maxiter::Int=10)
3535
x
3636
end
3737

38-
mutable struct DenseJacobiIterable{matT,vecT}
38+
mutable struct DenseJacobiIterable{matT,vecT,solT,rhsT}
3939
A::matT
40-
x::vecT
40+
x::solT
4141
next::vecT
42-
b::vecT
42+
b::rhsT
4343
maxiter::Int
4444
end
4545

@@ -93,10 +93,10 @@ function gauss_seidel!(x, A::AbstractMatrix, b; maxiter::Int=10)
9393
x
9494
end
9595

96-
mutable struct DenseGaussSeidelIterable{matT,vecT}
96+
mutable struct DenseGaussSeidelIterable{matT,solT,rhsT}
9797
A::matT
98-
x::vecT
99-
b::vecT
98+
x::solT
99+
b::rhsT
100100
maxiter::Int
101101
end
102102

@@ -149,11 +149,11 @@ function sor!(x, A::AbstractMatrix, b, ω::Real; maxiter::Int=10)
149149
x
150150
end
151151

152-
mutable struct DenseSORIterable{matT,vecT,numT}
152+
mutable struct DenseSORIterable{matT,solT,vecT,rhsT,numT}
153153
A::matT
154-
x::vecT
154+
x::solT
155155
tmp::vecT
156-
b::vecT
156+
b::rhsT
157157
ω::numT
158158
maxiter::Int
159159
end
@@ -207,11 +207,11 @@ function ssor!(x, A::AbstractMatrix, b, ω::Real; maxiter::Int=10)
207207
x
208208
end
209209

210-
mutable struct DenseSSORIterable{matT,vecT,numT}
210+
mutable struct DenseSSORIterable{matT,solT,vecT,rhsT,numT}
211211
A::matT
212-
x::vecT
212+
x::solT
213213
tmp::vecT
214-
b::vecT
214+
b::rhsT
215215
ω::numT
216216
maxiter::Int
217217
end

0 commit comments

Comments
 (0)