Skip to content

Commit b025aef

Browse files
authored
Merge pull request #7 from ranjanan/morefix
Fix Interpolation bugs
2 parents f2ea5af + 0e875bb commit b025aef

File tree

4 files changed

+78
-38
lines changed

4 files changed

+78
-38
lines changed

src/classical.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ function ruge_stuben(A::SparseMatrixCSC;
2222

2323
while length(levels) < max_levels
2424
A = extend_heirarchy!(levels, strength, CF, A)
25-
if size(levels[end].A, 1) < max_coarse
25+
if size(A, 1) < max_coarse
2626
break
2727
end
2828
end
29-
MultiLevel(levels, presmoother, postsmoother)
29+
MultiLevel(levels, A, presmoother, postsmoother)
3030
end
3131

3232
function extend_heirarchy!(levels::Vector{Level}, strength, CF, A)
@@ -44,7 +44,7 @@ function direct_interpolation{T,V}(A::T, S::T, splitting::Vector{V})
4444
Pp = rs_direct_interpolation_pass1(S, A, splitting)
4545
Pp = Pp .+ 1
4646

47-
Px, Pj = rs_direct_interpolation_pass2(A, S, splitting, Pp)
47+
Px, Pj, Pp = rs_direct_interpolation_pass2(A, S, splitting, Pp)
4848

4949
# Px .= abs.(Px)
5050
Pj = Pj .+ 1
@@ -59,6 +59,7 @@ end
5959
function rs_direct_interpolation_pass1(S, A, splitting)
6060

6161
Bp = zeros(Int, size(A.colptr))
62+
T = S'
6263
#=Sp = S.colptr
6364
Sj = S.rowval
6465
n_nodes = size(A, 1)
@@ -82,8 +83,8 @@ function rs_direct_interpolation_pass1(S, A, splitting)
8283
if splitting[i] == C_NODE
8384
nnz += 1
8485
else
85-
for j in nzrange(S, i)
86-
row = S.rowval[j]
86+
for j in nzrange(T, i)
87+
row = T.rowval[j]
8788
if splitting[row] == C_NODE && row != i
8889
nnz += 1
8990
end
@@ -101,8 +102,9 @@ function rs_direct_interpolation_pass1(S, A, splitting)
101102
Bp::Vector{Ti})
102103

103104

104-
Bx = zeros(Float64, Bp[end])
105-
Bj = zeros(Ti, Bp[end])
105+
T = S'
106+
Bx = zeros(Float64, Bp[end] - 1)
107+
Bj = zeros(Ti, Bp[end] - 1)
106108

107109
n = size(A, 1)
108110

@@ -113,9 +115,9 @@ function rs_direct_interpolation_pass1(S, A, splitting)
113115
else
114116
sum_strong_pos = zero(Tv)
115117
sum_strong_neg = zero(Tv)
116-
for j in nzrange(S, i)
117-
row = S.rowval[j]
118-
sval = S.nzval[j]
118+
for j in nzrange(T, i)
119+
row = T.rowval[j]
120+
sval = T.nzval[j]
119121
if splitting[row] == C_NODE && row != i
120122
if sval < 0
121123
sum_strong_neg += sval
@@ -153,9 +155,9 @@ function rs_direct_interpolation_pass1(S, A, splitting)
153155

154156
nnz = Bp[i]
155157

156-
for j in nzrange(S, i)
157-
row = S.rowval[j]
158-
sval = S.nzval[j]
158+
for j in nzrange(T, i)
159+
row = T.rowval[j]
160+
sval = T.nzval[j]
159161
if splitting[row] == C_NODE && row != i
160162
Bj[nnz] = row
161163
if sval < 0
@@ -175,10 +177,15 @@ function rs_direct_interpolation_pass1(S, A, splitting)
175177
m[i] = sum
176178
sum += splitting[i]
177179
end
178-
for i = 1:Bp[n]
179-
Bj[i] == 0 && continue
180-
Bj[i] = m[Bj[i]]
181-
end
180+
#@show m
181+
#@show Bj
182+
#l = issymmetric(S)? Bp[n]: Bp[n] + 1
183+
#@show l
184+
#for i = 1:l
185+
#Bj[i] == 0 && continue
186+
# Bj[i] = m[Bj[i]]
187+
#end
188+
Bj .= m[Bj]
182189

183190
#=Ap = A.colptr
184191
Aj = A.rowval

src/multilevel.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ struct Level{Ti,Tv}
44
R::SparseMatrixCSC{Ti,Tv}
55
end
66

7-
struct MultiLevel{L, S, Pre, Post}
7+
struct MultiLevel{L, S, Pre, Post, Ti, Tv}
88
levels::Vector{L}
9+
final_A::SparseMatrixCSC{Ti,Tv}
910
coarse_solver::S
1011
presmoother::Pre
1112
postsmoother::Post
@@ -15,25 +16,28 @@ abstract type CoarseSolver end
1516
struct Pinv <: CoarseSolver
1617
end
1718

18-
MultiLevel(l::Vector{Level}, presmoother, postsmoother; coarse_solver = Pinv()) =
19-
MultiLevel(l, coarse_solver, presmoother, postsmoother)
19+
MultiLevel(l::Vector{Level}, A, presmoother, postsmoother; coarse_solver = Pinv()) =
20+
MultiLevel(l, A, coarse_solver, presmoother, postsmoother)
21+
Base.length(ml) = length(ml.levels) + 1
2022

2123
function Base.show(io::IO, ml::MultiLevel)
22-
op = operator_complexity(ml.levels)
23-
g = grid_complexity(ml.levels)
24+
op = operator_complexity(ml)
25+
g = grid_complexity(ml)
2426
c = ml.coarse_solver
25-
total_nnz = sum(nnz(level.A) for level in ml.levels)
27+
total_nnz = sum(nnz(level.A) for level in ml.levels) + nnz(ml.final_A)
2628
lstr = ""
2729
for (i, level) in enumerate(ml.levels)
2830
lstr = lstr *
2931
@sprintf " %2d %10d %10d [%5.2f%%]\n" i size(level.A, 1) nnz(level.A) (100 * nnz(level.A) / total_nnz)
3032
end
33+
lstr = lstr *
34+
@sprintf " %2d %10d %10d [%5.2f%%]" length(ml.levels) + 1 size(ml.final_A, 1) nnz(ml.final_A) (100 * nnz(ml.final_A) / total_nnz)
3135
str = """
3236
Multilevel Solver
3337
-----------------
34-
Operator Complexity: $op
35-
Grid Complexity: $g
36-
No. of Levels: $(size(ml.levels, 1))
38+
Operator Complexity: $(round(op, 3))
39+
Grid Complexity: $(round(g, 3))
40+
No. of Levels: $(length(ml))
3741
Coarse Solver: $c
3842
Level Unknowns NonZeros
3943
----- -------- --------
@@ -42,12 +46,12 @@ function Base.show(io::IO, ml::MultiLevel)
4246
print(io, str)
4347
end
4448

45-
function operator_complexity(ml::Vector{Level})
46-
sum(nnz(level.A) for level in ml) / nnz(ml[1].A)
49+
function operator_complexity(ml::MultiLevel)
50+
(sum(nnz(level.A) for level in ml.levels) + nnz(ml.final_A)) / nnz(ml.levels[1].A)
4751
end
4852

49-
function grid_complexity(ml::Vector{Level})
50-
sum(size(level.A, 1) for level in ml) / size(ml[1].A, 1)
53+
function grid_complexity(ml::MultiLevel)
54+
(sum(size(level.A, 1) for level in ml.levels) + size(ml.final_A, 1)) / size(ml.levels[1].A, 1)
5155
end
5256

5357
abstract type Cycle end
@@ -83,8 +87,8 @@ function __solve{T}(v::V, ml, x::Vector{T}, b::Vector{T}, lvl)
8387
coarse_b = ml.levels[lvl].R * res
8488
coarse_x = zeros(T, size(coarse_b))
8589

86-
if lvl == length(ml.levels) - 1
87-
coarse_x = coarse_solver(ml.coarse_solver, ml.levels[end].A, coarse_b)
90+
if lvl == length(ml.levels)
91+
coarse_x = coarse_solver(ml.coarse_solver, ml.final_A, coarse_b)
8892
else
8993
coarse_x = __solve(v, ml, coarse_x, coarse_b, lvl + 1)
9094
end

test/randlap.jld

42 KB
Binary file not shown.

test/runtests.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using AMG
22
using Base.Test
33
using JLD
44

5+
@testset "AMG Tests" begin
6+
57
graph = load("test.jld")["G"]
68
ref_S = load("ref_S_test.jld")["G"]
79
ref_split = readdlm("ref_split_test.txt")
@@ -27,7 +29,6 @@ end
2729
# Ruge-Stuben splitting
2830
S = poisson(7)
2931
@test split_nodes(RS(), S) == [0, 1, 0, 1, 0, 1, 0]
30-
@show "buzz"
3132
srand(0)
3233
S = sprand(10,10,0.1); S = S + S'
3334
@test split_nodes(RS(), S) == [0, 1, 1, 0, 0, 0, 0, 0, 1, 1]
@@ -54,7 +55,9 @@ P, R = AMG.direct_interpolation(A, copy(A), splitting)
5455
0.0 1.0 0.0
5556
0.0 0.5 0.5
5657
0.0 0.0 1.0 ]
57-
58+
A = load("thing.jld")["G"]
59+
ml = ruge_stuben(A)
60+
@test size(ml.levels[2].A, 1) == 19
5861
end
5962

6063
@testset "Coarse Solver" begin
@@ -67,13 +70,31 @@ end
6770
A = poisson(1000)
6871
A = float.(A) #FIXME
6972
ml = AMG.ruge_stuben(A)
70-
@test length(ml.levels) == 8
71-
s = [1000, 500, 250, 125, 62, 31, 15, 7]
72-
n = [2998, 1498, 748, 373, 184, 91, 43, 19]
73-
for i = 1:8
73+
@test length(ml) == 8
74+
s = [1000, 500, 250, 125, 62, 31, 15]
75+
n = [2998, 1498, 748, 373, 184, 91, 43]
76+
for i = 1:7
7477
@test size(ml.levels[i].A, 1) == s[i]
7578
@test nnz(ml.levels[i].A) == n[i]
7679
end
80+
@test size(ml.final_A, 1) == 7
81+
@test nnz(ml.final_A) == 19
82+
83+
A = load("randlap.jld")["G"]
84+
ml = ruge_stuben(A)
85+
@test length(ml) == 3
86+
s = [100, 17]
87+
n = [2066, 289]
88+
for i = 1:2
89+
@test size(ml.levels[i].A, 1) == s[i]
90+
@test nnz(ml.levels[i].A) == n[i]
91+
end
92+
@test size(ml.final_A, 1) == 2
93+
@test nnz(ml.final_A) == 4
94+
@test round(AMG.operator_complexity(ml), 3) 1.142
95+
@test round(AMG.grid_complexity(ml), 3) 1.190
96+
97+
7798
end
7899

79100
@testset "Solver" begin
@@ -82,4 +103,12 @@ A = float.(A)
82103
ml = ruge_stuben(A)
83104
x = solve(ml, A * ones(1000))
84105
@test sum(abs2, x - ones(1000)) < 1e-10
106+
107+
A = load("randlap.jld")["G"]
108+
ml = ruge_stuben(A)
109+
x = solve(ml, A * ones(100))
110+
@test sum(abs2, x - zeros(100)) < 1e-10
111+
112+
end
113+
85114
end

0 commit comments

Comments
 (0)