Skip to content

Commit 089fede

Browse files
committed
Fix preconditioning
1 parent e5725d6 commit 089fede

File tree

4 files changed

+15
-24
lines changed

4 files changed

+15
-24
lines changed

src/multilevel.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,18 @@ abstract type Cycle end
5858
struct V <: Cycle
5959
end
6060

61-
function solve{T}(ml::MultiLevel, b::Vector{T}; maxiter = 100,
62-
cycle = V(),
63-
tol = 1e-5)
61+
function solve{T}(ml::MultiLevel, b::Vector{T},
62+
maxiter = 100,
63+
cycle = V(),
64+
tol = 1e-5)
6465
x = zeros(T, size(b))
6566
residuals = Vector{T}()
6667
A = ml.levels[1].A
6768
normb = norm(b)
6869
if normb != 0
6970
tol *= normb
7071
end
71-
push!(residuals, norm(b - A*x))
72+
push!(residuals, T(norm(b - A*x)))
7273

7374
lvl = 1
7475
while length(residuals) <= maxiter && residuals[end] > tol
@@ -77,7 +78,7 @@ function solve{T}(ml::MultiLevel, b::Vector{T}; maxiter = 100,
7778
else
7879
x = __solve(cycle, ml, x, b, lvl)
7980
end
80-
push!(residuals, norm(b - A * x))
81+
push!(residuals, T(norm(b - A * x)))
8182
end
8283
x
8384
end
@@ -103,4 +104,5 @@ function __solve{T}(v::V, ml, x::Vector{T}, b::Vector{T}, lvl)
103104
x
104105
end
105106

106-
coarse_solver(::Pinv, A, b) = pinv(full(A)) * b
107+
coarse_solver{Tv,Ti}(::Pinv, A::SparseMatrixCSC{Tv,Ti}, b::Vector{Tv}) =
108+
pinv(full(A)) * b

src/preconditioner.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77
aspreconditioner(ml::MultiLevel) = Preconditioner(ml)
88

99
\(p::Preconditioner, b) = p * b
10-
*(p::Preconditioner, b) = solve(p.ml, b; cycle = V(), maxiter = 1, tol = 1e-12)
10+
*(p::Preconditioner, b) = solve(p.ml, b, 1, V(), 1e-12)
1111

1212
A_ldiv_B!(x, p::Preconditioner, b) = copy!(x, p \ b)
1313
A_mul_B!(b, p::Preconditioner, x) = A_mul_B!(b, p.ml.levels[1].A, x)

src/splitting.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,6 @@ const U_NODE = 2
44

55
struct RS
66
end
7-
#=function split_nodes(::RS, S)
8-
n = size(S, 1)
9-
for i = 1:n
10-
for j in nzrange(S, i)
11-
row = S.rowval[j]
12-
if row == i
13-
S.nzval[j] = 0
14-
end
15-
end
16-
end
17-
RS_CF_splitting(sparse(i,j,v,n,n), sparse(j,i,v,n,n))
18-
end=#
197

208
function remove_diag!(a)
219
n = size(a, 1)

test/runtests.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using AMG
22
using Base.Test
33
using JLD
44
using IterativeSolvers
5+
import AMG: V, coarse_solver, Pinv, Classical
56

67
@testset "AMG Tests" begin
78

@@ -35,7 +36,7 @@ S = sprand(10,10,0.1); S = S + S'
3536
@test split_nodes(RS(), S) == [0, 1, 1, 0, 0, 0, 0, 0, 1, 1]
3637

3738
a = load("thing.jld")["G"]
38-
S, T = AMG.strength_of_connection(AMG.Classical(0.25), a)
39+
S, T = AMG.strength_of_connection(Classical(0.25), a)
3940
@test split_nodes(RS(), S) == [0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0,
4041
0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0,
4142
1, 0]
@@ -62,9 +63,9 @@ ml = ruge_stuben(A)
6263
end
6364

6465
@testset "Coarse Solver" begin
65-
A = poisson(10)
66+
A = float.(poisson(10))
6667
b = A * ones(10)
67-
@test sum(abs2, AMG.coarse_solver(AMG.Pinv(), A, b) - ones(10)) < 1e-6
68+
@test sum(abs2, coarse_solver(Pinv(), A, b) - ones(10)) < 1e-6
6869
end
6970

7071
@testset "Multilevel" begin
@@ -120,7 +121,7 @@ p = aspreconditioner(ml)
120121
b = zeros(n)
121122
b[1] = 1
122123
b[2] = -1
123-
x = solve(p.ml, A * ones(n); maxiter = 1, tol = 1e-12)
124+
x = solve(p.ml, A * ones(n), 1, V(), 1e-12)
124125
diff = x - [ 1.88664780e-16, 2.34982727e-16, 2.33917697e-16,
125126
8.77869044e-17, 7.16783490e-17, 1.43415460e-16,
126127
3.69199021e-17, 9.70950385e-17, 4.77034895e-17,
@@ -138,7 +139,7 @@ diff = x - [ 1.88664780e-16, 2.34982727e-16, 2.33917697e-16,
138139
-6.76965535e-16, -7.00643227e-16, -6.23581397e-16,
139140
-7.03016682e-16]
140141
@test sum(abs2, diff) < 1e-8
141-
x = solve(p.ml, b; maxiter = 1, tol = 1e-12)
142+
x = solve(p.ml, b, 1, V(), 1e-12)
142143
diff = x - [ 0.76347046, -0.5498286 , -0.2705487 , -0.15047352, -0.10248021,
143144
0.60292674, -0.11497073, -0.08460548, -0.06931461, 0.38230708,
144145
-0.055664 , -0.04854558, -0.04577031, 0.09964325, 0.01825624,

0 commit comments

Comments
 (0)