Skip to content

Commit 08f9724

Browse files
authored
Merge pull request #21 from ranjanan/single
Support single precision
2 parents c43419c + f42c843 commit 08f9724

File tree

8 files changed

+48
-31
lines changed

8 files changed

+48
-31
lines changed

src/aggregate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function aggregation(::StandardAggregation, S)
9595

9696
Tp = collect(1:n+1)
9797
x .= x .+ 1
98-
Tx = ones(length(x))
98+
Tx = ones(eltype(S), length(x))
9999

100100
SparseMatrixCSC(N, M, Tp, x, Tx)
101101
end

src/aggregation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function smoothed_aggregation{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti},
1+
function smoothed_aggregation(A::SparseMatrixCSC{T,V},
22
symmetry = HermitianSymmetry(),
33
strength = SymmetricStrength(),
44
aggregate = StandardAggregation(),
@@ -10,12 +10,12 @@ function smoothed_aggregation{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti},
1010
max_coarse = 10,
1111
diagonal_dominance = false,
1212
keep = false,
13-
coarse_solver = Pinv())
13+
coarse_solver = Pinv()) where {T,V}
1414

1515

1616
n = size(A, 1)
1717
# B = kron(ones(n, 1), eye(1))
18-
B = ones(n)
18+
B = ones(T,n)
1919

2020
#=max_levels, max_coarse, strength =
2121
levelize_strength_or_aggregation(max_levels, max_coarse, strength)
@@ -28,7 +28,7 @@ function smoothed_aggregation{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti},
2828
# agg = [aggregate for _ in 1:max_levels - 1]
2929
# sm = [smooth for _ in 1:max_levels]
3030

31-
levels = Vector{Level{Tv,Ti}}()
31+
levels = Vector{Level{T,V}}()
3232
bsr_flag = false
3333

3434
while length(levels) + 1 < max_levels && size(A, 1) > max_coarse

src/classical.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ end
3939

4040
function direct_interpolation(A, T, splitting)
4141

42-
fill!(T.nzval, 1.)
42+
fill!(T.nzval, eltype(A)(1.))
4343
T .= A .* T
4444
Pp = rs_direct_interpolation_pass1(T, A, splitting)
4545
Pp .= Pp .+ 1
@@ -77,13 +77,13 @@ function rs_direct_interpolation_pass1(T, A, splitting)
7777
end
7878

7979

80-
function rs_direct_interpolation_pass2{Tv, Ti}(A::SparseMatrixCSC{Tv,Ti},
80+
function rs_direct_interpolation_pass2(A::SparseMatrixCSC{Tv,Ti},
8181
T::SparseMatrixCSC{Tv, Ti},
8282
splitting::Vector{Ti},
83-
Bp::Vector{Ti})
83+
Bp::Vector{Ti}) where {Tv,Ti}
8484

85-
86-
Bx = zeros(Float64, Bp[end] - 1)
85+
86+
Bx = zeros(Tv, Bp[end] - 1)
8787
Bj = zeros(Ti, Bp[end] - 1)
8888

8989
n = size(A, 1)
@@ -152,7 +152,7 @@ function rs_direct_interpolation_pass1(T, A, splitting)
152152
end
153153

154154
m = zeros(Ti, n)
155-
sum = zero(eltype(m))
155+
sum = zero(Ti)
156156
for i = 1:n
157157
m[i] = sum
158158
sum += splitting[i]

src/multilevel.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
struct Level{Ti,Tv}
2-
A::SparseMatrixCSC{Ti,Tv}
3-
P::SparseMatrixCSC{Ti,Tv}
4-
R::SparseMatrixCSC{Ti,Tv}
1+
struct Level{T,V}
2+
A::SparseMatrixCSC{T,V}
3+
P::SparseMatrixCSC{T,V}
4+
R::SparseMatrixCSC{T,V}
55
end
66

77
struct MultiLevel{S, Pre, Post, Ti, Tv}
@@ -66,8 +66,10 @@ function solve{T}(ml::MultiLevel, b::Vector{T},
6666
tol = 1e-5;
6767
verbose = false,
6868
log = false)
69-
x = zeros(T, size(b))
70-
residuals = Vector{T}()
69+
V = promote_type(eltype(ml.levels[1].A), eltype(b))
70+
x = zeros(V, size(b))
71+
tol = eltype(b)(tol)
72+
residuals = Vector{V}()
7173
A = length(ml) == 1 ? ml.final_A : ml.levels[1].A
7274
normb = norm(b)
7375
if normb != 0
@@ -91,14 +93,14 @@ function solve{T}(ml::MultiLevel, b::Vector{T},
9193
return x
9294
end
9395
end
94-
function __solve{T}(v::V, ml, x::Vector{T}, b::Vector{T}, lvl)
96+
function __solve(v::V, ml, x, b, lvl)
9597

9698
A = ml.levels[lvl].A
9799
presmoother!(ml.presmoother, A, x, b)
98100

99101
res = b - A * x
100102
coarse_b = ml.levels[lvl].R * res
101-
coarse_x = zeros(T, size(coarse_b))
103+
coarse_x = zeros(eltype(coarse_b), size(coarse_b))
102104

103105
if lvl == length(ml.levels)
104106
coarse_x = coarse_solver(ml.coarse_solver, ml.final_A, coarse_b)
@@ -113,5 +115,4 @@ function __solve{T}(v::V, ml, x::Vector{T}, b::Vector{T}, lvl)
113115
x
114116
end
115117

116-
coarse_solver{Tv,Ti}(::Pinv, A::SparseMatrixCSC{Tv,Ti}, b::Vector{Tv}) =
117-
pinv(full(A)) * b
118+
coarse_solver(::Pinv, A, b) = pinv(full(A)) * b

src/smoother.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ smoother!(s::GaussSeidel, ::BackwardSweep, A, x, b) =
3232
gs!(A, b, x, size(A,1), -1, 1)
3333

3434

35-
function gs!{T,Ti}(A::SparseMatrixCSC{T,Ti}, b::Vector{T}, x::Vector{T}, start, step, stop)
35+
function gs!(A, b, x, start, step, stop)
3636
n = size(A, 1)
3737
z = zero(eltype(A))
3838
for i = start:step:stop
@@ -110,7 +110,7 @@ end
110110
function weight(::DiagonalWeighting, S, ω)
111111
D_inv = 1 ./ diag(S)
112112
D_inv_S = scale_rows(S, D_inv)
113-
(ω / approximate_spectral_radius(D_inv_S)) * D_inv_S
113+
(eltype(S)(ω) / approximate_spectral_radius(D_inv_S)) * D_inv_S
114114
#(ω) * D_inv_S
115115
end
116116

src/strength.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ struct Classical{T} <: Strength
44
end
55
Classical(;θ = 0.25) = Classical(θ)
66

7-
function strength_of_connection{T, Ti, Tv}(c::Classical{T}, A::SparseMatrixCSC{Tv, Ti})
7+
function strength_of_connection(c::Classical{T},
8+
A::SparseMatrixCSC{Tv,Ti}) where {T,Ti,Tv}
89

910
θ = c.θ
1011

@@ -77,13 +78,11 @@ SymmetricStrength() = SymmetricStrength(0.)
7778

7879
function strength_of_connection{T}(s::SymmetricStrength{T}, A, bsr_flag = false)
7980

80-
81-
8281
θ = s.θ
8382

8483
if bsr_flag && θ == 0
8584
S = SparseMatrixCSC(size(A)...,
86-
A.colptr, A.rowval, ones(size(A.rowval)))
85+
A.colptr, A.rowval, ones(eltype(A), size(A.rowval)))
8786
return S
8887
else
8988
S = deepcopy(A)

src/utils.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ function approximate_spectral_radius(A, tol = 0.01,
55
symmetric = false
66

77
# Initial guess
8-
v0 = rand(size(A,1))
8+
v0 = rand(eltype(A), size(A,1))
99
maxiter = min(size(A, 1), maxiter)
1010
ev = zeros(eltype(A), maxiter)
1111
max_index = 0
12-
X = zeros(size(A,1), maxiter)
12+
X = zeros(eltype(A), size(A,1), maxiter)
1313

1414
for i in 1:restart+1
1515
evect, ev, H, V, flag =
@@ -90,8 +90,7 @@ function approximate_eigenvalues(A, tol, maxiter, symmetric, v0)
9090
scale!(w, 1/H[j+1,j])
9191
push!(V, w)
9292
end
93-
94-
Eigs, Vects = eig(H[1:maxiter, 1:maxiter], eye(maxiter))
93+
Eigs, Vects = eig(H[1:maxiter, 1:maxiter], eye(eltype(A), maxiter))
9594

9695
Vects, Eigs, H, V, flag
9796
end

test/runtests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ end
5151
# Direct Interpolation
5252
using AMG
5353
A = poisson(5)
54+
A = Float64.(A)
5455
splitting = [1,0,1,0,1]
5556
P, R = AMG.direct_interpolation(A, copy(A), splitting)
5657
@test P == [ 1.0 0.0 0.0
@@ -216,4 +217,21 @@ diff = x - [0.775725, -0.571202, -0.290989, -0.157001, -0.106981, 0.622652,
216217

217218
end
218219

220+
@testset "Precision" begin
221+
222+
a = poisson(100)
223+
b = rand(size(a,1))
224+
225+
# Iterate through all types
226+
for (T,V) in ((Float64, Float64), (Float32,Float32),
227+
(Float64,Float32), (Float32,Float64))
228+
a = T.(a)
229+
ml = smoothed_aggregation(a)
230+
b = V.(b)
231+
c = cg(a, b, maxiter = 10)
232+
@test eltype(solve(ml, b)) == eltype(c)
233+
end
234+
219235
end
236+
237+
end

0 commit comments

Comments
 (0)