Skip to content

Commit 4dc1ccd

Browse files
authored
Merge pull request #2 from ranjanan/levels
Construct multilevel solver
2 parents 5a0c548 + 352cddf commit 4dc1ccd

File tree

5 files changed

+86
-29
lines changed

5 files changed

+86
-29
lines changed

src/AMG.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ export split_nodes, RS
99
include("gallery.jl")
1010
export poisson
1111

12+
include("smoother.jl")
13+
14+
include("multilevel.jl")
15+
1216
include("classical.jl")
1317

1418
end # module

src/classical.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,38 @@ struct Solver{S,T,P,PS}
77
max_coarse::Int64
88
end
99

10-
struct Level{T}
11-
A::T
12-
end
13-
1410
function ruge_stuben(A::SparseMatrixCSC;
15-
strength = Classical(),
11+
strength = Classical(0.25),
1612
CF = RS(),
1713
presmoother = GaussSiedel(),
1814
postsmoother = GaussSiedel(),
1915
max_levels = 10,
2016
max_coarse = 500)
2117

22-
s = Solver(strength, CF, presmoother,
23-
postsmoother, max_levels, max_levels)
18+
s = Solver(strength, CF, presmoother,
19+
postsmoother, max_levels, max_levels)
2420

25-
levels = [Level(A)]
21+
levels = Vector{Level}()
2622

27-
while length(levels) < max_levels && size(levels[end].A, 1)
28-
extend_heirarchy!(levels, strength, CF, A)
23+
while length(levels) < max_levels
24+
A = extend_heirarchy!(levels, strength, CF, A)
25+
if size(levels[end].A, 1) < max_coarse
26+
break
2927
end
28+
end
29+
MultiLevel(levels)
3030
end
3131

3232
function extend_heirarchy!(levels::Vector{Level}, strength, CF, A)
3333
S = strength_of_connection(strength, A)
3434
splitting = split_nodes(CF, S)
3535
P, R = direct_interpolation(A, S, splitting)
36+
push!(levels, Level(A, P, R))
37+
A = R * A * P
3638
end
3739

3840
function direct_interpolation{T,V}(A::T, S::T, splitting::Vector{V})
39-
41+
4042
fill!(S.nzval, 1.)
4143
S = A .* S
4244
Pp = rs_direct_interpolation_pass1(S, A, splitting)
@@ -65,7 +67,8 @@ function rs_direct_interpolation_pass1(S, A, splitting)
6567
if splitting[i] == C_NODE
6668
nnz += 1
6769
else
68-
for jj = Sp[i]:Sp[i+1]-1
70+
for jj = Sp[i]:Sp[i+1]
71+
jj > length(Sj) && continue
6972
if splitting[Sj[jj]] == C_NODE && Sj[jj] != i
7073
nnz += 1
7174
end
@@ -92,7 +95,6 @@ function rs_direct_interpolation_pass1(S, A, splitting)
9295
Bj = zeros(Ti, Bp[end])
9396
Bx = zeros(Float64, Bp[end])
9497
n_nodes = size(A, 1)
95-
#Bp += 1
9698

9799
for i = 1:n_nodes
98100
if splitting[i] == C_NODE
@@ -101,7 +103,8 @@ function rs_direct_interpolation_pass1(S, A, splitting)
101103
else
102104
sum_strong_pos = 0
103105
sum_strong_neg = 0
104-
for jj = Sp[i]: Sp[i+1]-1
106+
for jj = Sp[i]: Sp[i+1]
107+
jj > length(Sj) && continue
105108
if splitting[Sj[jj]] == C_NODE && Sj[jj] != i
106109
if Sx[jj] < 0
107110
sum_strong_neg += Sx[jj]
@@ -115,6 +118,7 @@ function rs_direct_interpolation_pass1(S, A, splitting)
115118
sum_all_neg = 0
116119
diag = 0;
117120
for jj = Ap[i]:Ap[i+1]
121+
jj > length(Aj) && continue
118122
if Aj[jj] == i
119123
diag += Ax[jj]
120124
else
@@ -134,19 +138,20 @@ function rs_direct_interpolation_pass1(S, A, splitting)
134138
beta = 0
135139
end
136140

137-
neg_coeff = -alpha/diag;
138-
pos_coeff = -beta/diag;
141+
neg_coeff = -alpha / diag
142+
pos_coeff = -beta / diag
139143

140144
nnz = Bp[i]
141145
for jj = Sp[i]:Sp[i+1]
146+
jj > length(Sj) && continue
142147
if splitting[Sj[jj]] == C_NODE && Sj[jj] != i
143148
Bj[nnz] = Sj[jj]
144149
if Sx[jj] < 0
145150
Bx[nnz] = neg_coeff * Sx[jj]
146151
else
147152
Bx[nnz] = pos_coeff * Sx[jj]
148-
nnz += 1
149153
end
154+
nnz += 1
150155
end
151156
end
152157
end
@@ -159,6 +164,7 @@ function rs_direct_interpolation_pass1(S, A, splitting)
159164
sum += splitting[i]
160165
end
161166
for i = 1:Bp[n_nodes]
167+
Bj[i] == 0 && continue
162168
Bj[i] = m[Bj[i]]
163169
end
164170

src/multilevel.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,48 @@
11
struct Level{Ti,Tv}
22
A::SparseMatrixCSC{Ti,Tv}
3+
P::SparseMatrixCSC{Ti,Tv}
4+
R::SparseMatrixCSC{Ti,Tv}
5+
end
6+
7+
struct MultiLevel{L, S}
8+
levels::Vector{L}
9+
coarse_solver::S
10+
end
11+
12+
abstract type CoarseSolver end
13+
struct Pinv <: CoarseSolver
14+
end
15+
MultiLevel(l::Vector{Level}; coarse_solver = Pinv()) =
16+
MultiLevel(l, coarse_solver)
17+
18+
function Base.show(io::IO, ml::MultiLevel)
19+
op = operator_complexity(ml.levels)
20+
g = grid_complexity(ml.levels)
21+
c = ml.coarse_solver
22+
total_nnz = sum(nnz(level.A) for level in ml.levels)
23+
lstr = ""
24+
for (i, level) in enumerate(ml.levels)
25+
lstr = lstr *
26+
@sprintf " %2d %10d %10d [%5.2f%%]\n" i size(level.A, 1) nnz(level.A) (100 * nnz(level.A) / total_nnz)
27+
end
28+
str = """
29+
Multilevel Solver
30+
-----------------
31+
Operator Complexity: $op
32+
Grid Complexity: $g
33+
No. of Levels: $(size(ml.levels, 1))
34+
Coarse Solver: $c
35+
Level Unknowns NonZeros
36+
----- -------- --------
37+
$lstr
38+
"""
39+
print(io, str)
40+
end
41+
42+
function operator_complexity(ml::Vector{Level})
43+
sum(nnz(level.A) for level in ml) / nnz(ml[1].A)
44+
end
45+
46+
function grid_complexity(ml::Vector{Level})
47+
sum(size(level.A, 1) for level in ml) / size(ml[1].A, 1)
348
end

src/smoother.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
abstract type Smoother end
2+
struct GaussSiedel <: Smoother
3+
end

src/strength.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ abstract type Strength end
22
struct Classical{T} <: Strength
33
θ::T
44
end
5+
Classical(;θ = 0.25) = Classical(θ)
56

67
function strength_of_connection{T}(c::Classical{T}, A::SparseMatrixCSC)
78

@@ -14,8 +15,8 @@ function strength_of_connection{T}(c::Classical{T}, A::SparseMatrixCSC)
1415

1516
for i = 1:n
1617
neighbors = A[:,i]
17-
m = find_max_off_diag(neighbors, i)
18-
threshold = θ * m
18+
_m = find_max_off_diag(neighbors, i)
19+
threshold = θ * _m
1920
for j in nzrange(A, i)
2021
row = A.rowval[j]
2122
val = A.nzval[j]
@@ -26,19 +27,17 @@ function strength_of_connection{T}(c::Classical{T}, A::SparseMatrixCSC)
2627
end
2728
end
2829
end
29-
S = sparse(I, J, V)
30+
S = sparse(I, J, V, m, n)
3031

3132
scale_cols_by_largest_entry(S)
3233
end
3334

3435
function find_max_off_diag(neighbors, col)
35-
max_offdiag = 0
36-
for (i,v) in enumerate(neighbors)
37-
if col != i
38-
max_offdiag = max(max_offdiag, abs(v))
39-
end
36+
maxval = zero(eltype(neighbors))
37+
for i in 1:length(neighbors.nzval)
38+
maxval = max(maxval, ifelse(neighbors.nzind[i] == col, 0, abs(neighbors.nzval[i])))
4039
end
41-
max_offdiag
40+
return maxval
4241
end
4342

4443
function scale_cols_by_largest_entry(A::SparseMatrixCSC)
@@ -51,16 +50,16 @@ function scale_cols_by_largest_entry(A::SparseMatrixCSC)
5150

5251
k = 1
5352
for i = 1:n
54-
m = maximum(A[:,i])
53+
_m = maximum(A[:,i])
5554
for j in nzrange(A, i)
5655
row = A.rowval[j]
5756
val = A.nzval[j]
5857
I[k] = row
5958
J[k] = i
60-
V[k] = val / m
59+
V[k] = val / _m
6160
k += 1
6261
end
6362
end
6463

65-
sparse(I,J,V)
64+
sparse(I,J,V,m,n)
6665
end

0 commit comments

Comments
 (0)