Skip to content

Commit 5d83e64

Browse files
committed
Rewrite the RS splitting algorithm
1 parent 4093be9 commit 5d83e64

File tree

4 files changed

+113
-28
lines changed

4 files changed

+113
-28
lines changed

src/classical.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ function direct_interpolation{T,V}(A::T, S::T, splitting::Vector{V})
4242
fill!(S.nzval, 1.)
4343
S = A .* S
4444
Pp = rs_direct_interpolation_pass1(S, A, splitting)
45-
Pp .= Pp .+ 1
45+
Pp = Pp .+ 1
4646

4747
Px, Pj = rs_direct_interpolation_pass2(A, S, splitting, Pp)
4848

4949
# Px .= abs.(Px)
50-
Pj .= Pj .+ 1
50+
Pj = Pj .+ 1
5151

5252
R = SparseMatrixCSC(maximum(Pj), size(A, 1), Pp, Pj, Px)
5353
P = R'
@@ -59,7 +59,7 @@ end
5959
function rs_direct_interpolation_pass1(S, A, splitting)
6060

6161
Bp = zeros(Int, size(A.colptr))
62-
Sp = S.colptr
62+
#=Sp = S.colptr
6363
Sj = S.rowval
6464
n_nodes = size(A, 1)
6565
nnz = 0
@@ -75,6 +75,21 @@ function rs_direct_interpolation_pass1(S, A, splitting)
7575
end
7676
end
7777
Bp[i+1] = nnz
78+
end=#
79+
n = size(A, 1)
80+
nnz = 0
81+
for i = 1:n
82+
if splitting[i] == C_NODE
83+
nnz += 1
84+
else
85+
for j in nzrange(S, i)
86+
row = S.rowval[j]
87+
if splitting[row] == C_NODE && row != i
88+
nnz += 1
89+
end
90+
end
91+
end
92+
Bp[i+1] = nnz
7893
end
7994
Bp
8095
end

src/splitting.jl

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ function RS_CF_splitting(S::SparseMatrixCSC)
1212

1313
n_nodes = n
1414
lambda = zeros(Int, n)
15-
Tp = S.colptr
16-
Tj = S.rowval
17-
Sp = Tp
18-
Sj = Tj
15+
T = S'
16+
Tp = T.colptr
17+
Tj = T.rowval
18+
Sp = S.colptr
19+
Sj = S.rowval
1920

20-
# compute lambdas
21+
# compute lambdas - number of neighbors
2122
for i = 1:n
2223
lambda[i] = Tp[i+1] - Tp[i]
2324
end
@@ -27,6 +28,7 @@ function RS_CF_splitting(S::SparseMatrixCSC)
2728
index_to_node = zeros(Int,n)
2829
node_to_index = zeros(Int,n)
2930

31+
# Number of nodes with a certain neighbor count
3032
for i = 1:n
3133
interval_count[lambda[i]+1] += 1
3234
end
@@ -52,9 +54,83 @@ function RS_CF_splitting(S::SparseMatrixCSC)
5254
splitting[i] = F_NODE
5355
end
5456
end
57+
@show lambda
58+
@show node_to_index
59+
@show index_to_node
60+
@show interval_ptr
61+
62+
for top_index = n_nodes:-1:1
63+
i = index_to_node[top_index]
64+
lambda_i = lambda[i] + 1
65+
interval_count[lambda_i] -= 1
66+
if splitting[i] == F_NODE
67+
continue
68+
else
69+
@assert splitting[i] == U_NODE
70+
splitting[i] = C_NODE
71+
for j in nzrange(T, i)
72+
row = T.rowval[j]
73+
if splitting[row] == U_NODE
74+
splitting[row] = F_NODE
75+
76+
for k in nzrange(S, row)
77+
rowk = S.rowval[k]
78+
if splitting[rowk] == U_NODE
79+
lambda[rowk] >= n_nodes - 1 && continue
80+
lambda_k = lambda[rowk] + 1
81+
old_pos = node_to_index[rowk]
82+
new_pos = interval_ptr[lambda_k] + interval_count[lambda_k]# - 1
83+
84+
node_to_index[index_to_node[old_pos]] = new_pos
85+
node_to_index[index_to_node[new_pos]] = old_pos
86+
(index_to_node[old_pos], index_to_node[new_pos]) = (index_to_node[new_pos], index_to_node[old_pos])
87+
88+
# update intervals
89+
interval_count[lambda_k] -= 1
90+
interval_count[lambda_k + 1] += 1 # invalid write!
91+
interval_ptr[lambda_k + 1] = new_pos - 1
92+
93+
# increment lambda_k
94+
lambda[rowk] += 1
95+
end
96+
end
97+
end
98+
end
99+
for j in nzrange(S, i)
100+
row = S.rowval[j]
101+
if splitting[row] == U_NODE
102+
103+
lambda[row] == 0 && continue
104+
105+
# assert(lambda[j] > 0);//this would cause problems!
106+
107+
# move j to the beginning of its current interval
108+
lambda_j = lambda[row] + 1
109+
old_pos = node_to_index[row]
110+
new_pos = interval_ptr[lambda_j]
111+
112+
node_to_index[index_to_node[old_pos]] = new_pos
113+
node_to_index[index_to_node[new_pos]] = old_pos
114+
(index_to_node[old_pos],index_to_node[new_pos]) = (index_to_node[new_pos],index_to_node[old_pos])
115+
116+
# update intervals
117+
interval_count[lambda_j] -= 1
118+
interval_count[lambda_j-1] += 1
119+
interval_ptr[lambda_j] += 1
120+
interval_ptr[lambda_j-1] = interval_ptr[lambda_j] - interval_count[lambda_j-1]
121+
122+
# decrement lambda_j
123+
lambda[row] -= 1
124+
end
125+
end
126+
end
127+
end
128+
splitting
129+
end
130+
55131

56132
# Now add elements to C and F, in descending order of lambda
57-
for top_index = n_nodes:-1:1
133+
#=for top_index = n_nodes:-1:1
58134
i = index_to_node[top_index]
59135
lambda_i = lambda[i] + 1
60136
@@ -74,6 +150,7 @@ function RS_CF_splitting(S::SparseMatrixCSC)
74150
75151
# For each j in S^T_i /\ U
76152
for jj = Tp[i]:Tp[i+1]-1
153+
#jj > length(Tp) && continue
77154
78155
j = Tj[jj]
79156
@@ -82,6 +159,7 @@ function RS_CF_splitting(S::SparseMatrixCSC)
82159
83160
# For each k in S_j /\ U
84161
for kk = Sp[j]: Sp[j+1]-1
162+
# kk > length(Sj) && continue
85163
k = Sj[kk]
86164
87165
if splitting[k] == U_NODE
@@ -109,7 +187,8 @@ function RS_CF_splitting(S::SparseMatrixCSC)
109187
end
110188
111189
# For each j in S_i /\ U
112-
for jj = Sp[i]: Sp[i+1]-1
190+
for jj = Sp[i]: Sp[i+1] - 1
191+
# jj > length(Sj) && continue
113192
114193
j = Sj[jj]
115194
@@ -141,4 +220,4 @@ function RS_CF_splitting(S::SparseMatrixCSC)
141220
end
142221
end
143222
splitting
144-
end
223+
end=#

src/strength.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ function strength_of_connection{T}(c::Classical{T}, A::SparseMatrixCSC)
2929
end
3030
S = sparse(I, J, V, m, n)
3131

32-
scale_cols_by_largest_entry(S)
32+
scale_cols_by_largest_entry!(S)
33+
34+
S'
3335
end
3436

3537
function find_max_off_diag(neighbors, col)
@@ -40,26 +42,15 @@ function find_max_off_diag(neighbors, col)
4042
return maxval
4143
end
4244

43-
function scale_cols_by_largest_entry(A::SparseMatrixCSC)
44-
45-
m,n = size(A)
46-
47-
I = zeros(Int, size(A.nzval))
48-
J = similar(I)
49-
V = zeros(size(A.nzval))
45+
function scale_cols_by_largest_entry!(A::SparseMatrixCSC)
5046

51-
k = 1
47+
n = size(A, 1)
5248
for i = 1:n
5349
_m = maximum(A[:,i])
5450
for j in nzrange(A, i)
55-
row = A.rowval[j]
56-
val = A.nzval[j]
57-
I[k] = row
58-
J[k] = i
59-
V[k] = val / _m
60-
k += 1
51+
A.nzval[j] /= _m
6152
end
6253
end
6354

64-
sparse(I,J,V,m,n)
55+
A
6556
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919
# Ruge-Stuben splitting
2020
S = poisson(7)
2121
@test split_nodes(RS(), S) == [0, 1, 0, 1, 0, 1, 0]
22-
22+
@show "buzz"
2323
srand(0)
2424
S = sprand(10,10,0.1); S = S + S'
2525
@test split_nodes(RS(), S) == [0, 1, 1, 0, 0, 0, 0, 0, 1, 1]

0 commit comments

Comments
 (0)