Skip to content

Commit 8bc5deb

Browse files
committed
Generalize direct interpol for non-symmetric matrices and edge cases
1 parent 3913da1 commit 8bc5deb

File tree

5 files changed

+55
-44
lines changed

5 files changed

+55
-44
lines changed

src/classical.jl

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,26 @@ function ruge_stuben(A::SparseMatrixCSC{Ti,Tv};
2222

2323
while length(levels) + 1 < max_levels && size(A, 1) > max_coarse
2424
A = extend_heirarchy!(levels, strength, CF, A)
25-
#if size(A, 1) <= max_coarse
26-
# break
27-
#end
2825
end
2926
MultiLevel(levels, A, presmoother, postsmoother)
3027
end
3128

3229
function extend_heirarchy!(levels::Vector{Level{Ti,Tv}}, strength, CF, A::SparseMatrixCSC{Ti,Tv}) where {Ti,Tv}
33-
S, T = strength_of_connection(strength, A)
30+
At = copy(A')
31+
S, T = strength_of_connection(strength, At)
3432
splitting = split_nodes(CF, S)
35-
P, R = direct_interpolation(A, T, splitting)
33+
P, R = direct_interpolation(At, T, splitting)
3634
push!(levels, Level(A, P, R))
3735
A = R * A * P
3836
end
3937

40-
function direct_interpolation(A, T, splitting)
41-
fill!(T.nzval, eltype(A)(1))
42-
T .= A .* T
38+
function direct_interpolation(At, T, splitting)
39+
fill!(T.nzval, eltype(At)(1))
40+
T .= At .* T
41+
4342
Pp = rs_direct_interpolation_pass1(T, splitting)
44-
Px, Pj, Pp = rs_direct_interpolation_pass2(A, T, splitting, Pp)
45-
46-
R = SparseMatrixCSC(maximum(Pj), size(A, 1), Pp, Pj, Px)
43+
Px, Pj, Pp = rs_direct_interpolation_pass2(At, T, splitting, Pp)
44+
R = SparseMatrixCSC(maximum(Pj), size(At, 1), Pp, Pj, Px)
4745
P = copy(R')
4846

4947
P, R
@@ -60,7 +58,7 @@ function rs_direct_interpolation_pass1(T, splitting)
6058
else
6159
for j in nzrange(T, i)
6260
row = T.rowval[j]
63-
if splitting[row] == C_NODE && row != i
61+
if splitting[row] == C_NODE
6462
nnzplus1 += 1
6563
end
6664
end
@@ -71,15 +69,15 @@ function rs_direct_interpolation_pass1(T, splitting)
7169
end
7270

7371

74-
function rs_direct_interpolation_pass2(A::SparseMatrixCSC{Tv,Ti},
72+
function rs_direct_interpolation_pass2(At::SparseMatrixCSC{Tv,Ti},
7573
T::SparseMatrixCSC{Tv, Ti},
7674
splitting::Vector{Ti},
7775
Bp::Vector{Ti}) where {Tv,Ti}
7876

7977
Bx = zeros(Tv, Bp[end] - 1)
8078
Bj = zeros(Ti, Bp[end] - 1)
8179

82-
n = size(A, 1)
80+
n = size(At, 1)
8381

8482
for i = 1:n
8583
if splitting[i] == C_NODE
@@ -91,7 +89,7 @@ function rs_direct_interpolation_pass2(A::SparseMatrixCSC{Tv,Ti},
9189
for j in nzrange(T, i)
9290
row = T.rowval[j]
9391
sval = T.nzval[j]
94-
if splitting[row] == C_NODE && row != i
92+
if splitting[row] == C_NODE
9593
if sval < 0
9694
sum_strong_neg += sval
9795
else
@@ -102,9 +100,9 @@ function rs_direct_interpolation_pass2(A::SparseMatrixCSC{Tv,Ti},
102100
sum_all_pos = zero(Tv)
103101
sum_all_neg = zero(Tv)
104102
diag = zero(Tv)
105-
for j in nzrange(A, i)
106-
row = A.rowval[j]
107-
aval = A.nzval[j]
103+
for j in nzrange(At, i)
104+
row = At.rowval[j]
105+
aval = At.nzval[j]
108106
if row == i
109107
diag += aval
110108
else
@@ -115,28 +113,43 @@ function rs_direct_interpolation_pass2(A::SparseMatrixCSC{Tv,Ti},
115113
end
116114
end
117115
end
118-
alpha = sum_all_neg / sum_strong_neg
119-
beta = sum_all_pos / sum_strong_pos
120116

121117
if sum_strong_pos == 0
122-
diag += sum_all_pos
123-
beta = zero(beta)
118+
beta = zero(diag)
119+
if diag >= 0
120+
diag += sum_all_pos
121+
end
122+
else
123+
beta = sum_all_pos / sum_strong_pos
124124
end
125125

126-
neg_coeff = -1 * alpha / diag
127-
pos_coeff = -1 * beta / diag
126+
if sum_strong_neg == 0
127+
alpha = zero(diag)
128+
if diag < 0
129+
diag += sum_all_neg
130+
end
131+
else
132+
alpha = sum_all_neg / sum_strong_neg
133+
end
128134

129-
nnz = Bp[i]
135+
if isapprox(diag, 0, atol=eps(Tv))
136+
neg_coeff = Tv(0)
137+
pos_coeff = Tv(0)
138+
else
139+
neg_coeff = alpha / diag
140+
pos_coeff = beta / diag
141+
end
130142

143+
nnz = Bp[i]
131144
for j in nzrange(T, i)
132145
row = T.rowval[j]
133146
sval = T.nzval[j]
134-
if splitting[row] == C_NODE && row != i
147+
if splitting[row] == C_NODE
135148
Bj[nnz] = row
136149
if sval < 0
137-
Bx[nnz] = neg_coeff * sval
150+
Bx[nnz] = abs(neg_coeff * sval)
138151
else
139-
Bx[nnz] = pos_coeff * sval
152+
Bx[nnz] = abs(pos_coeff * sval)
140153
end
141154
nnz += 1
142155
end

src/multilevel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ struct V <: Cycle
7474
end
7575

7676
"""
77-
solve(ml::MultiLevel, b::Vector, cycle, kwargs...)
77+
solve(ml::MultiLevel, b::AbstractVector, cycle, kwargs...)
7878
7979
Execute multigrid cycling.
8080
@@ -92,7 +92,7 @@ Keyword Arguments
9292
* log::Bool - return vector of residuals along with solution
9393
9494
"""
95-
function solve(ml::MultiLevel, b::Vector{T},
95+
function solve(ml::MultiLevel, b::AbstractVector{T},
9696
cycle::Cycle = V();
9797
maxiter::Int = 100,
9898
tol::Float64 = 1e-5,

src/preconditioner.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import Compat.LinearAlgebra: \, *, ldiv!, mul!
2+
13
struct Preconditioner
24
ml::MultiLevel
35
end
@@ -10,6 +12,7 @@ aspreconditioner(ml::MultiLevel) = Preconditioner(ml)
1012
A_mul_B!(b, p::Preconditioner, x) = A_mul_B!(b, p.ml.levels[1].A, x)
1113
else
1214
import Compat.LinearAlgebra: \, *, ldiv!, mul!
15+
ldiv!(p::Preconditioner, b) = copyto!(b, p \ b)
1316
ldiv!(x, p::Preconditioner, b) = copyto!(x, p \ b)
1417
mul!(b, p::Preconditioner, x) = mul!(b, p.ml.levels[1].A, x)
1518
end

src/strength.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ end
55
Classical(;θ = 0.25) = Classical(θ)
66

77
function strength_of_connection(c::Classical,
8-
A::SparseMatrixCSC{Tv,Ti}) where {Ti,Tv}
8+
At::SparseMatrixCSC{Tv,Ti}) where {Ti,Tv}
99

1010
θ = c.θ
1111

12-
m, n = size(A)
13-
T = copy(A')
12+
m, n = size(At)
13+
T = deepcopy(At)
1414

1515
for i = 1:n
1616
_m = find_max_off_diag(T, i)
@@ -29,7 +29,7 @@ function strength_of_connection(c::Classical,
2929

3030
end
3131
end
32-
32+
3333
dropzeros!(T)
3434

3535
scale_cols_by_largest_entry!(T)

test/sa_tests.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ function symmetric_soc(A::SparseMatrixCSC{T,V}, θ) where {T,V}
1313

1414
S = sparse(i,j,v, size(A)...) + spdiagm(0=>D)
1515

16-
scale_cols_by_largest_entry!(S)
17-
1816
for i = 1:size(S.nzval,1)
1917
S.nzval[i] = abs(S.nzval[i])
2018
end
2119

20+
scale_cols_by_largest_entry!(S)
21+
2222
S
2323
end
2424

@@ -222,7 +222,7 @@ function test_approximate_spectral_radius()
222222
for A in cases
223223
A = A + A'
224224
@static if VERSION < v"0.7-"
225-
E,V = eig(A)
225+
E,V = eig(A)
226226
else
227227
E,V = (eigen(A)...,)
228228
end
@@ -231,15 +231,12 @@ function test_approximate_spectral_radius()
231231
expected_eig = E[largest_eig]
232232

233233
@test isapprox(approximate_spectral_radius(A), expected_eig)
234-
235234
end
236-
237235
end
238236

239237
# Test Gauss Seidel
240238
import AlgebraicMultigrid: gs!, relax!
241-
function test_gauss_seidel()
242-
239+
function test_gauss_seidel()
243240
N = 1
244241
A = spdiagm(0 => 2 * ones(N), -1 => -ones(N-1), 1 => -ones(N-1))
245242
x = eltype(A).(collect(0:N-1))
@@ -331,7 +328,5 @@ function test_symmetric_sweep()
331328
relax!(s, A, x, b)
332329
@test sum(abs2, x - [0.176765; 0.353529; 0.497517; 0.598914;
333330
0.653311; 0.659104; 0.615597; 0.52275;
334-
0.382787; 0.203251]) < 1e-6
335-
331+
0.382787; 0.203251]) < 1e-6
336332
end
337-

0 commit comments

Comments
 (0)