Skip to content

Commit a3aa1eb

Browse files
authored
Merge pull request #8 from ranjanan/strength
Some clean up
2 parents b025aef + 78c4a87 commit a3aa1eb

File tree

5 files changed

+41
-238
lines changed

5 files changed

+41
-238
lines changed

src/classical.jl

Lines changed: 14 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct Solver{S,T,P,PS}
77
max_coarse::Int64
88
end
99

10-
function ruge_stuben(A::SparseMatrixCSC;
10+
function ruge_stuben{Ti,Tv}(A::SparseMatrixCSC{Ti,Tv};
1111
strength = Classical(0.25),
1212
CF = RS(),
1313
presmoother = GaussSeidel(),
@@ -18,7 +18,7 @@ function ruge_stuben(A::SparseMatrixCSC;
1818
s = Solver(strength, CF, presmoother,
1919
postsmoother, max_levels, max_levels)
2020

21-
levels = Vector{Level}()
21+
levels = Vector{Level{Ti,Tv}}()
2222

2323
while length(levels) < max_levels
2424
A = extend_heirarchy!(levels, strength, CF, A)
@@ -29,25 +29,24 @@ function ruge_stuben(A::SparseMatrixCSC;
2929
MultiLevel(levels, A, presmoother, postsmoother)
3030
end
3131

32-
function extend_heirarchy!(levels::Vector{Level}, strength, CF, A)
33-
S = strength_of_connection(strength, A)
32+
function extend_heirarchy!{Ti,Tv}(levels::Vector{Level{Ti,Tv}}, strength, CF, A::SparseMatrixCSC{Ti,Tv})
33+
S, T = strength_of_connection(strength, A)
3434
splitting = split_nodes(CF, S)
35-
P, R = direct_interpolation(A, S, splitting)
35+
P, R = direct_interpolation(A, T, splitting)
3636
push!(levels, Level(A, P, R))
3737
A = R * A * P
3838
end
3939

40-
function direct_interpolation{T,V}(A::T, S::T, splitting::Vector{V})
40+
function direct_interpolation(A, T, splitting)
4141

42-
fill!(S.nzval, 1.)
43-
S = A .* S
44-
Pp = rs_direct_interpolation_pass1(S, A, splitting)
45-
Pp = Pp .+ 1
42+
fill!(T.nzval, 1.)
43+
T = A .* T
44+
Pp = rs_direct_interpolation_pass1(T, A, splitting)
45+
Pp .= Pp .+ 1
4646

47-
Px, Pj, Pp = rs_direct_interpolation_pass2(A, S, splitting, Pp)
47+
Px, Pj, Pp = rs_direct_interpolation_pass2(A, T, splitting, Pp)
4848

49-
# Px .= abs.(Px)
50-
Pj = Pj .+ 1
49+
Pj .= Pj .+ 1
5150

5251
R = SparseMatrixCSC(maximum(Pj), size(A, 1), Pp, Pj, Px)
5352
P = R'
@@ -56,27 +55,9 @@ function direct_interpolation{T,V}(A::T, S::T, splitting::Vector{V})
5655
end
5756

5857

59-
function rs_direct_interpolation_pass1(S, A, splitting)
58+
function rs_direct_interpolation_pass1(T, A, splitting)
6059

6160
Bp = zeros(Int, size(A.colptr))
62-
T = S'
63-
#=Sp = S.colptr
64-
Sj = S.rowval
65-
n_nodes = size(A, 1)
66-
nnz = 0
67-
for i = 1:n_nodes
68-
if splitting[i] == C_NODE
69-
nnz += 1
70-
else
71-
for jj = Sp[i]:Sp[i+1]
72-
jj > length(Sj) && continue
73-
if splitting[Sj[jj]] == C_NODE && Sj[jj] != i
74-
nnz += 1
75-
end
76-
end
77-
end
78-
Bp[i+1] = nnz
79-
end=#
8061
n = size(A, 1)
8162
nnz = 0
8263
for i = 1:n
@@ -97,12 +78,11 @@ function rs_direct_interpolation_pass1(S, A, splitting)
9778

9879

9980
function rs_direct_interpolation_pass2{Tv, Ti}(A::SparseMatrixCSC{Tv,Ti},
100-
S::SparseMatrixCSC{Tv,Ti},
81+
T::SparseMatrixCSC{Tv, Ti},
10182
splitting::Vector{Ti},
10283
Bp::Vector{Ti})
10384

10485

105-
T = S'
10686
Bx = zeros(Float64, Bp[end] - 1)
10787
Bj = zeros(Ti, Bp[end] - 1)
10888

@@ -177,105 +157,7 @@ function rs_direct_interpolation_pass1(S, A, splitting)
177157
m[i] = sum
178158
sum += splitting[i]
179159
end
180-
#@show m
181-
#@show Bj
182-
#l = issymmetric(S)? Bp[n]: Bp[n] + 1
183-
#@show l
184-
#for i = 1:l
185-
#Bj[i] == 0 && continue
186-
# Bj[i] = m[Bj[i]]
187-
#end
188160
Bj .= m[Bj]
189161

190-
#=Ap = A.colptr
191-
Aj = A.rowval
192-
Ax = A.nzval
193-
Sp = S.colptr
194-
Sj = S.rowval
195-
Sx = S.nzval
196-
Bj = zeros(Ti, Bp[end])
197-
Bx = zeros(Float64, Bp[end])
198-
n_nodes = size(A, 1)
199-
200-
for i = 1:n_nodes
201-
if splitting[i] == C_NODE
202-
Bj[Bp[i]] = i
203-
Bx[Bp[i]] = 1
204-
else
205-
sum_strong_pos = 0
206-
sum_strong_neg = 0
207-
for jj = Sp[i]: Sp[i+1]
208-
jj > length(Sj) && continue
209-
if splitting[Sj[jj]] == C_NODE && Sj[jj] != i
210-
if Sx[jj] < 0
211-
sum_strong_neg += Sx[jj]
212-
else
213-
sum_strong_pos += Sx[jj]
214-
end
215-
end
216-
end
217-
218-
sum_all_pos = 0
219-
sum_all_neg = 0
220-
diag = 0
221-
for jj = Ap[i]:Ap[i+1]
222-
jj > length(Aj) && continue
223-
if Aj[jj] == i
224-
@show Ax[jj]
225-
diag += Ax[jj]
226-
else
227-
if Ax[jj] < 0
228-
sum_all_neg += Ax[jj]
229-
else
230-
sum_all_pos += Ax[jj]
231-
end
232-
end
233-
end
234-
235-
alpha = sum_all_neg / sum_strong_neg
236-
beta = sum_all_pos / sum_strong_pos
237-
@show alpha
238-
@show beta
239-
@show diag
240-
241-
if sum_strong_pos == 0
242-
diag += sum_all_pos
243-
beta = 0
244-
end
245-
246-
neg_coeff = -1 * alpha / diag
247-
pos_coeff = -1 * beta / diag
248-
249-
@show neg_coeff
250-
@show pos_coeff
251-
252-
nnz = Bp[i]
253-
for jj = Sp[i]:Sp[i+1]
254-
jj > length(Sj) && continue
255-
if splitting[Sj[jj]] == C_NODE && Sj[jj] != i
256-
Bj[nnz] = Sj[jj]
257-
if Sx[jj] < 0
258-
Bx[nnz] = neg_coeff * Sx[jj]
259-
else
260-
Bx[nnz] = pos_coeff * Sx[jj]
261-
end
262-
@show Bx[nnz]
263-
nnz += 1
264-
end
265-
end
266-
end
267-
end
268-
269-
m = zeros(Ti, n_nodes)
270-
sum = 0
271-
for i = 1:n_nodes
272-
m[i] = sum
273-
sum += splitting[i]
274-
end
275-
for i = 1:Bp[n_nodes]
276-
Bj[i] == 0 && continue
277-
Bj[i] = m[Bj[i]]
278-
end =#
279-
280162
Bx, Bj, Bp
281163
end

src/multilevel.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ struct Level{Ti,Tv}
44
R::SparseMatrixCSC{Ti,Tv}
55
end
66

7-
struct MultiLevel{L, S, Pre, Post, Ti, Tv}
8-
levels::Vector{L}
7+
struct MultiLevel{S, Pre, Post, Ti, Tv}
8+
levels::Vector{Level{Ti,Tv}}
99
final_A::SparseMatrixCSC{Ti,Tv}
1010
coarse_solver::S
1111
presmoother::Pre
@@ -16,7 +16,7 @@ abstract type CoarseSolver end
1616
struct Pinv <: CoarseSolver
1717
end
1818

19-
MultiLevel(l::Vector{Level}, A, presmoother, postsmoother; coarse_solver = Pinv()) =
19+
MultiLevel{Ti,Tv}(l::Vector{Level{Ti,Tv}}, A::SparseMatrixCSC{Ti,Tv}, presmoother, postsmoother; coarse_solver = Pinv()) =
2020
MultiLevel(l, A, coarse_solver, presmoother, postsmoother)
2121
Base.length(ml) = length(ml.levels) + 1
2222

src/splitting.jl

Lines changed: 19 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,31 @@ const U_NODE = 2
44

55
struct RS
66
end
7+
#=function split_nodes(::RS, S)
8+
n = size(S, 1)
9+
for i = 1:n
10+
for j in nzrange(S, i)
11+
row = S.rowval[j]
12+
if row == i
13+
S.nzval[j] = 0
14+
end
15+
end
16+
end
17+
i, j, v = findnz(S)
18+
RS_CF_splitting(sparse(i,j,v,n,n), sparse(j,i,v,n,n))
19+
end=#
20+
function split_nodes(::RS, S)
21+
T = S'
22+
RS_CF_splitting(S - spdiagm(diag(S)), T - spdiagm(diag(T)))
23+
end
724

8-
split_nodes(::RS, S::SparseMatrixCSC) = RS_CF_splitting(S - spdiagm(diag(S)))
9-
function RS_CF_splitting(S::SparseMatrixCSC)
25+
function RS_CF_splitting(S::SparseMatrixCSC, T::SparseMatrixCSC)
1026

1127
m,n = size(S)
1228

1329
n_nodes = n
1430
lambda = zeros(Int, n)
15-
T = S'
31+
1632
Tp = T.colptr
1733
Tj = T.rowval
1834
Sp = S.colptr
@@ -100,8 +116,6 @@ function RS_CF_splitting(S::SparseMatrixCSC)
100116

101117
lambda[row] == 0 && continue
102118

103-
# assert(lambda[j] > 0);//this would cause problems!
104-
105119
# move j to the beginning of its current interval
106120
lambda_j = lambda[row] + 1
107121
old_pos = node_to_index[row]
@@ -125,97 +139,3 @@ function RS_CF_splitting(S::SparseMatrixCSC)
125139
end
126140
splitting
127141
end
128-
129-
130-
# Now add elements to C and F, in descending order of lambda
131-
#=for top_index = n_nodes:-1:1
132-
i = index_to_node[top_index]
133-
lambda_i = lambda[i] + 1
134-
135-
# if (n_nodes == 4)
136-
# std::cout << "selecting node #" << i << " with lambda " << lambda[i] << std::endl;
137-
138-
# remove i from its interval
139-
interval_count[lambda_i] -= 1
140-
141-
if splitting[i] == F_NODE
142-
continue
143-
else
144-
145-
@assert splitting[i] == U_NODE
146-
147-
splitting[i] = C_NODE
148-
149-
# For each j in S^T_i /\ U
150-
for jj = Tp[i]:Tp[i+1]-1
151-
#jj > length(Tp) && continue
152-
153-
j = Tj[jj]
154-
155-
if splitting[j] == U_NODE
156-
splitting[j] = F_NODE
157-
158-
# For each k in S_j /\ U
159-
for kk = Sp[j]: Sp[j+1]-1
160-
# kk > length(Sj) && continue
161-
k = Sj[kk]
162-
163-
if splitting[k] == U_NODE
164-
# move k to the end of its current interval
165-
lambda[k] >= n_nodes - 1 && continue
166-
167-
lambda_k = lambda[k] + 1
168-
old_pos = node_to_index[k]
169-
new_pos = interval_ptr[lambda_k] + interval_count[lambda_k]# - 1
170-
171-
node_to_index[index_to_node[old_pos]] = new_pos
172-
node_to_index[index_to_node[new_pos]] = old_pos
173-
(index_to_node[old_pos], index_to_node[new_pos]) = (index_to_node[new_pos], index_to_node[old_pos])
174-
175-
# update intervals
176-
interval_count[lambda_k] -= 1
177-
interval_count[lambda_k+1] += 1 # invalid write!
178-
interval_ptr[lambda_k+1] = new_pos - 1
179-
180-
# increment lambda_k
181-
lambda[k] += 1
182-
end
183-
end
184-
end
185-
end
186-
187-
# For each j in S_i /\ U
188-
for jj = Sp[i]: Sp[i+1] - 1
189-
# jj > length(Sj) && continue
190-
191-
j = Sj[jj]
192-
193-
if splitting[j] == U_NODE # decrement lambda for node j
194-
195-
lambda[j] == 0 && continue
196-
197-
# assert(lambda[j] > 0);//this would cause problems!
198-
199-
# move j to the beginning of its current interval
200-
lambda_j = lambda[j] + 1
201-
old_pos = node_to_index[j]
202-
new_pos = interval_ptr[lambda_j]
203-
204-
node_to_index[index_to_node[old_pos]] = new_pos
205-
node_to_index[index_to_node[new_pos]] = old_pos
206-
(index_to_node[old_pos],index_to_node[new_pos]) = (index_to_node[new_pos],index_to_node[old_pos])
207-
208-
# update intervals
209-
interval_count[lambda_j] -= 1
210-
interval_count[lambda_j-1] += 1
211-
interval_ptr[lambda_j] += 1
212-
interval_ptr[lambda_j-1] = interval_ptr[lambda_j] - interval_count[lambda_j-1]
213-
214-
# decrement lambda_j
215-
lambda[j] -= 1
216-
end
217-
end
218-
end
219-
end
220-
splitting
221-
end=#

src/strength.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ function strength_of_connection{T}(c::Classical{T}, A::SparseMatrixCSC)
3939

4040
scale_cols_by_largest_entry!(S)
4141

42-
S'
42+
43+
S', S
4344
end
4445

4546
function find_max_off_diag(neighbors, col)

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ ref_split = readdlm("ref_split_test.txt")
1212

1313
# classical strength of connection
1414
A = poisson(5)
15-
S = strength_of_connection(Classical(0.2), A)
15+
S, T = strength_of_connection(Classical(0.2), A)
1616
@test full(S) == [ 1.0 0.5 0.0 0.0 0.0
1717
0.5 1.0 0.5 0.0 0.0
1818
0.0 0.5 1.0 0.5 0.0
1919
0.0 0.0 0.5 1.0 0.5
2020
0.0 0.0 0.0 0.5 1.0 ]
21-
S = strength_of_connection(Classical(0.25), graph)
21+
S, T = strength_of_connection(Classical(0.25), graph)
2222
diff = S - ref_S
2323
@test maximum(diff) < 1e-10
2424

@@ -34,7 +34,7 @@ S = sprand(10,10,0.1); S = S + S'
3434
@test split_nodes(RS(), S) == [0, 1, 1, 0, 0, 0, 0, 0, 1, 1]
3535

3636
a = load("thing.jld")["G"]
37-
S = AMG.strength_of_connection(AMG.Classical(0.25), a)
37+
S, T = AMG.strength_of_connection(AMG.Classical(0.25), a)
3838
@test split_nodes(RS(), S) == [0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0,
3939
0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0,
4040
1, 0]

0 commit comments

Comments
 (0)