Skip to content

Commit 4fb1768

Browse files
committed
Some fixes and improvements to multilevel
1 parent 4dc1ccd commit 4fb1768

File tree

6 files changed

+192
-17
lines changed

6 files changed

+192
-17
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
julia 0.6
2+
IterativeSolvers 0.4.1

src/AMG.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module AMG
22

3+
import IterativeSolvers: gauss_seidel!
4+
35
include("strength.jl")
46
export strength_of_connection, Classical
57

src/classical.jl

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ end
1010
function ruge_stuben(A::SparseMatrixCSC;
1111
strength = Classical(0.25),
1212
CF = RS(),
13-
presmoother = GaussSiedel(),
14-
postsmoother = GaussSiedel(),
13+
presmoother = GaussSeidel(),
14+
postsmoother = GaussSeidel(),
1515
max_levels = 10,
16-
max_coarse = 500)
16+
max_coarse = 10)
1717

1818
s = Solver(strength, CF, presmoother,
1919
postsmoother, max_levels, max_levels)
@@ -26,7 +26,7 @@ function ruge_stuben(A::SparseMatrixCSC;
2626
break
2727
end
2828
end
29-
MultiLevel(levels)
29+
MultiLevel(levels, presmoother, postsmoother)
3030
end
3131

3232
function extend_heirarchy!(levels::Vector{Level}, strength, CF, A)
@@ -46,7 +46,7 @@ function direct_interpolation{T,V}(A::T, S::T, splitting::Vector{V})
4646

4747
Px, Pj = rs_direct_interpolation_pass2(A, S, splitting, Pp)
4848

49-
Px .= abs.(Px)
49+
# Px .= abs.(Px)
5050
Pj .= Pj .+ 1
5151

5252
R = SparseMatrixCSC(maximum(Pj), size(A, 1), Pp, Pj, Px)
@@ -86,7 +86,86 @@ function rs_direct_interpolation_pass1(S, A, splitting)
8686
Bp::Vector{Ti})
8787

8888

89-
Ap = A.colptr
89+
Bx = zeros(Float64, Bp[end])
90+
Bj = zeros(Ti, Bp[end])
91+
92+
n = size(A, 1)
93+
94+
for i = 1:n
95+
if splitting[i] == C_NODE
96+
Bj[Bp[i]] = i
97+
Bx[Bp[i]] = 1
98+
else
99+
sum_strong_pos = zero(Tv)
100+
sum_strong_neg = zero(Tv)
101+
for j in nzrange(S, i)
102+
row = S.rowval[j]
103+
sval = S.nzval[j]
104+
if splitting[row] == C_NODE && row != i
105+
if sval < 0
106+
sum_strong_neg += sval
107+
else
108+
sum_strong_pos += sval
109+
end
110+
end
111+
end
112+
sum_all_pos = zero(Tv)
113+
sum_all_neg = zero(Tv)
114+
diag = zero(Tv)
115+
for j in nzrange(A, i)
116+
row = A.rowval[j]
117+
aval = A.nzval[j]
118+
if row == i
119+
diag += aval
120+
else
121+
if aval < 0
122+
sum_all_neg += aval
123+
else
124+
sum_all_pos += aval
125+
end
126+
end
127+
end
128+
alpha = sum_all_neg / sum_strong_neg
129+
beta = sum_all_pos / sum_strong_pos
130+
131+
if sum_strong_pos == 0
132+
diag += sum_all_pos
133+
beta = 0
134+
end
135+
136+
neg_coeff = -1 * alpha / diag
137+
pos_coeff = -1 * beta / diag
138+
139+
nnz = Bp[i]
140+
141+
for j in nzrange(S, i)
142+
row = S.rowval[j]
143+
sval = S.nzval[j]
144+
if splitting[row] == C_NODE && row != i
145+
Bj[nnz] = row
146+
if sval < 0
147+
Bx[nnz] = neg_coeff * sval
148+
else
149+
Bx[nnz] = pos_coeff * sval
150+
end
151+
nnz += 1
152+
end
153+
end
154+
end
155+
end
156+
157+
m = zeros(Ti, n)
158+
sum = 0
159+
for i = 1:n
160+
m[i] = sum
161+
sum += splitting[i]
162+
end
163+
for i = 1:Bp[n]
164+
Bj[i] == 0 && continue
165+
Bj[i] = m[Bj[i]]
166+
end
167+
168+
#=Ap = A.colptr
90169
Aj = A.rowval
91170
Ax = A.nzval
92171
Sp = S.colptr
@@ -116,30 +195,37 @@ function rs_direct_interpolation_pass1(S, A, splitting)
116195
117196
sum_all_pos = 0
118197
sum_all_neg = 0
119-
diag = 0;
198+
diag = 0
120199
for jj = Ap[i]:Ap[i+1]
121200
jj > length(Aj) && continue
122201
if Aj[jj] == i
202+
@show Ax[jj]
123203
diag += Ax[jj]
124204
else
125205
if Ax[jj] < 0
126-
sum_all_neg += Ax[jj];
206+
sum_all_neg += Ax[jj]
127207
else
128-
sum_all_pos += Ax[jj];
208+
sum_all_pos += Ax[jj]
129209
end
130210
end
131211
end
132212
133213
alpha = sum_all_neg / sum_strong_neg
134214
beta = sum_all_pos / sum_strong_pos
215+
@show alpha
216+
@show beta
217+
@show diag
135218
136219
if sum_strong_pos == 0
137220
diag += sum_all_pos
138221
beta = 0
139222
end
140223
141-
neg_coeff = -alpha / diag
142-
pos_coeff = -beta / diag
224+
neg_coeff = -1 * alpha / diag
225+
pos_coeff = -1 * beta / diag
226+
227+
@show neg_coeff
228+
@show pos_coeff
143229
144230
nnz = Bp[i]
145231
for jj = Sp[i]:Sp[i+1]
@@ -151,6 +237,7 @@ function rs_direct_interpolation_pass1(S, A, splitting)
151237
else
152238
Bx[nnz] = pos_coeff * Sx[jj]
153239
end
240+
@show Bx[nnz]
154241
nnz += 1
155242
end
156243
end
@@ -166,7 +253,7 @@ function rs_direct_interpolation_pass1(S, A, splitting)
166253
for i = 1:Bp[n_nodes]
167254
Bj[i] == 0 && continue
168255
Bj[i] = m[Bj[i]]
169-
end
256+
end =#
170257

171258
Bx, Bj, Bp
172259
end

src/multilevel.jl

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@ struct Level{Ti,Tv}
44
R::SparseMatrixCSC{Ti,Tv}
55
end
66

7-
struct MultiLevel{L, S}
7+
struct MultiLevel{L, S, Pre, Post}
88
levels::Vector{L}
99
coarse_solver::S
10+
presmoother::Pre
11+
postsmoother::Post
1012
end
1113

1214
abstract type CoarseSolver end
1315
struct Pinv <: CoarseSolver
1416
end
15-
MultiLevel(l::Vector{Level}; coarse_solver = Pinv()) =
16-
MultiLevel(l, coarse_solver)
17+
18+
MultiLevel(l::Vector{Level}, presmoother, postsmoother; coarse_solver = Pinv()) =
19+
MultiLevel(l, coarse_solver, presmoother, postsmoother)
1720

1821
function Base.show(io::IO, ml::MultiLevel)
1922
op = operator_complexity(ml.levels)
@@ -46,3 +49,53 @@ end
4649
function grid_complexity(ml::Vector{Level})
4750
sum(size(level.A, 1) for level in ml) / size(ml[1].A, 1)
4851
end
52+
53+
abstract type Cycle end
54+
struct V <: Cycle
55+
end
56+
57+
function solve{T}(ml::MultiLevel, b::Vector{T}; maxiter = 100,
58+
cycle = V(),
59+
tol = 1e-5)
60+
x = zeros(T, size(b))
61+
residuals = Vector{T}()
62+
A = ml.levels[1].A
63+
normb = norm(b)
64+
push!(residuals, norm(b - A*x))
65+
66+
lvl = 1
67+
while length(residuals) <= maxiter && residuals[end] > tol
68+
if length(ml.levels) == 1
69+
x = coarse_solver(ml.coarse_solver, A, b)
70+
else
71+
x = __solve(cycle, ml, x, b, lvl, residuals)
72+
end
73+
end
74+
x
75+
end
76+
function __solve{T}(v::V, ml, x::Vector{T}, b::Vector{T}, lvl, residuals)
77+
78+
@show lvl
79+
A = ml.levels[lvl].A
80+
presmoother!(ml.presmoother, A, x, b)
81+
82+
res = b - A * x
83+
@show norm(res)
84+
push!(residuals, norm(res))
85+
86+
coarse_b = ml.levels[lvl].R * res
87+
88+
if lvl == length(ml.levels) - 1
89+
coarse_x = coarse_solver(ml.coarse_solver, ml.levels[end].A, coarse_b)
90+
else
91+
coarse_x = __solve(v, ml, coarse_x, coarse_b, lvl + 1)
92+
end
93+
94+
x .+= ml.levels[lvl].P * coarse_x
95+
96+
postsmoother!(ml.postsmoother, A, x, b)
97+
98+
x
99+
end
100+
101+
coarse_solver(::Pinv, A, b) = pinv(full(A)) * b

src/smoother.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
11
abstract type Smoother end
2-
struct GaussSiedel <: Smoother
2+
abstract type Sweep end
3+
struct SymmetricSweep <: Sweep
34
end
5+
struct ForwardSweep <: Sweep
6+
end
7+
struct GaussSeidel{S} <: Smoother
8+
sweep::S
9+
end
10+
GaussSeidel(;sweep = ForwardSweep()) = GaussSeidel(sweep)
11+
12+
presmoother!(s, A, x, b) = smoother(s, s.sweep, A, x, b)
13+
postsmoother!(s, A, x, b) = smoother(s, s.sweep, A, x, b)
14+
15+
smoother(s::GaussSeidel, ::ForwardSweep, A, x, b) =
16+
gauss_seidel!(x, A, b, maxiter = 1)

test/runtests.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,30 @@ end
3232
using AMG
3333
A = poisson(5)
3434
splitting = [1,0,1,0,1]
35-
P, R = AMG.direct_interpolation(A, A, splitting)
35+
P, R = AMG.direct_interpolation(A, copy(A), splitting)
3636
@test P == [ 1.0 0.0 0.0
3737
0.5 0.5 0.0
3838
0.0 1.0 0.0
3939
0.0 0.5 0.5
4040
0.0 0.0 1.0 ]
4141

4242
end
43+
44+
@testset "Coarse Solver" begin
45+
A = poisson(10)
46+
b = A * ones(10)
47+
@test sum(abs, AMG.coarse_solver(AMG.Pinv(), A, b) - ones(10)) < 1e-6
48+
end
49+
50+
@testset "Multilevel" begin
51+
A = poisson(1000)
52+
A = float.(A) #FIXME
53+
ml = AMG.ruge_stuben(A)
54+
@test length(ml.levels) == 8
55+
s = [1000, 500, 250, 125, 62, 31, 15, 7]
56+
n = [2998, 1498, 748, 373, 184, 91, 43, 19]
57+
for i = 1:8
58+
@test size(ml.levels[i].A, 1) == s[i]
59+
@test nnz(ml.levels[i].A) == n[i]
60+
end
61+
end

0 commit comments

Comments
 (0)