Skip to content

Commit e89a333

Browse files
committed
use callable structs
1 parent 8bc5deb commit e89a333

File tree

11 files changed

+45
-65
lines changed

11 files changed

+45
-65
lines changed

src/AlgebraicMultigrid.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@ include("utils.jl")
1414
export approximate_spectral_radius
1515

1616
include("strength.jl")
17-
export strength_of_connection, Classical, SymmetricStrength
17+
export Classical, SymmetricStrength
1818

1919
include("splitting.jl")
20-
export split_nodes, RS
20+
export RS
2121

2222
include("gallery.jl")
2323
export poisson
2424

2525
include("smoother.jl")
2626
export GaussSeidel, SymmetricSweep, ForwardSweep, BackwardSweep,
27-
smooth_prolongator, JacobiProlongation
27+
JacobiProlongation
2828

2929
include("multilevel.jl")
3030
export solve
@@ -33,7 +33,7 @@ include("classical.jl")
3333
export ruge_stuben
3434

3535
include("aggregate.jl")
36-
export aggregation, StandardAggregation
36+
export StandardAggregation
3737

3838
include("aggregation.jl")
3939
export fit_candidates, smoothed_aggregation

src/aggregate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
struct StandardAggregation
22
end
33

4-
function aggregation(::StandardAggregation, S::SparseMatrixCSC{T,R}) where {T,R}
4+
function (::StandardAggregation)(S::SparseMatrixCSC{T,R}) where {T,R}
55

66
n = size(S, 1)
77
x = zeros(R, n)

src/aggregation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,18 @@ function extend_hierarchy!(levels, strength, aggregate, smooth,
5454
symmetry, bsr_flag)
5555

5656
# Calculate strength of connection matrix
57-
S = strength_of_connection(strength, A, bsr_flag)
57+
S = strength(A, bsr_flag)
5858

5959
# Aggregation operator
60-
AggOp = aggregation(aggregate, S)
60+
AggOp = aggregate(S)
6161
# b = zeros(eltype(A), size(A, 1))
6262

6363
# Improve candidates
6464
b = zeros(size(A,1))
65-
relax!(improve_candidates, A, B, b)
65+
improve_candidates(A, B, b)
6666
T, B = fit_candidates(AggOp, B)
6767

68-
P = smooth_prolongator(smooth, A, T, S, B)
68+
P = smooth(A, T, S, B)
6969
R = construct_R(symmetry, P)
7070
push!(levels, Level(A, P, R))
7171

src/classical.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ end
2828

2929
function extend_heirarchy!(levels::Vector{Level{Ti,Tv}}, strength, CF, A::SparseMatrixCSC{Ti,Tv}) where {Ti,Tv}
3030
At = copy(A')
31-
S, T = strength_of_connection(strength, At)
32-
splitting = split_nodes(CF, S)
31+
S, T = strength(At)
32+
splitting = CF(S)
3333
P, R = direct_interpolation(At, T, splitting)
3434
push!(levels, Level(A, P, R))
3535
A = R * A * P

src/multilevel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ end
130130
function __solve(v::V, ml, x, b, lvl)
131131

132132
A = ml.levels[lvl].A
133-
presmoother!(ml.presmoother, A, x, b)
133+
ml.presmoother(A, x, b)
134134

135135
res = b - A * x
136136
coarse_b = ml.levels[lvl].R * res
@@ -144,7 +144,7 @@ function __solve(v::V, ml, x, b, lvl)
144144

145145
x .+= ml.levels[lvl].P * coarse_x
146146

147-
postsmoother!(ml.postsmoother, A, x, b)
147+
ml.postsmoother(A, x, b)
148148

149149
x
150150
end

src/preconditioner.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import Compat.LinearAlgebra: \, *, ldiv!, mul!
2-
31
struct Preconditioner
42
ml::MultiLevel
53
end

src/smoother.jl

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,17 @@ GaussSeidel(f::ForwardSweep) = GaussSeidel(f, 1)
1515
GaussSeidel(b::BackwardSweep) = GaussSeidel(b, 1)
1616
GaussSeidel(s::SymmetricSweep) = GaussSeidel(s, 1)
1717

18-
presmoother!(s, A, x, b) = smoother!(s, s.sweep, A, x, b)
19-
postsmoother!(s, A, x, b) = smoother!(s, s.sweep, A, x, b)
20-
relax!(s, A, x, b) = smoother!(s, s.sweep, A, x, b)
21-
22-
function smoother!(s::GaussSeidel, ::ForwardSweep, A, x, b)
23-
for i in 1:s.iter
24-
gs!(A, b, x, 1, 1, size(A, 1))
25-
end
26-
end
27-
28-
function smoother!(s::GaussSeidel, ::SymmetricSweep, A, x, b)
29-
for i in 1:s.iter
30-
gs!(A, b, x, 1, 1, size(A, 1))
31-
gs!(A, b, x, size(A,1), -1, 1)
32-
end
33-
end
34-
35-
function smoother!(s::GaussSeidel, ::BackwardSweep, A, x, b)
18+
function (s::GaussSeidel{S})(A, x, b) where {S<:Sweep}
3619
for i in 1:s.iter
37-
gs!(A, b, x, size(A,1), -1, 1)
20+
if S === ForwardSweep || S === SymmetricSweep
21+
gs!(A, b, x, 1, 1, size(A, 1))
22+
end
23+
if S === BackwardSweep || S === SymmetricSweep
24+
gs!(A, b, x, size(A, 1), -1, 1)
25+
end
3826
end
3927
end
4028

41-
4229
function gs!(A, b, x, start, step, stop)
4330
n = size(A, 1)
4431
z = zero(eltype(A))
@@ -104,10 +91,7 @@ end
10491
struct LocalWeighting
10592
end
10693

107-
function smooth_prolongator(j::JacobiProlongation,
108-
A, T, S, B,
109-
degree = 1,
110-
weighting = LocalWeighting())
94+
function (j::JacobiProlongation)(A, T, S, B, degree = 1, weighting = LocalWeighting())
11195
D_inv_S = weight(weighting, A, j.ω)
11296
P = T
11397
for i = 1:degree

src/splitting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function remove_diag!(a)
1717
dropzeros!(a)
1818
end
1919

20-
function split_nodes(::RS, S)
20+
function (::RS)(S)
2121
remove_diag!(S)
2222
RS_CF_splitting(S, copy(S'))
2323
end

src/strength.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ struct Classical{T} <: Strength
44
end
55
Classical(;θ = 0.25) = Classical(θ)
66

7-
function strength_of_connection(c::Classical,
8-
At::SparseMatrixCSC{Tv,Ti}) where {Ti,Tv}
7+
function (c::Classical)(At::SparseMatrixCSC{Tv,Ti}) where {Ti,Tv}
98

109
θ = c.θ
1110

@@ -75,7 +74,7 @@ struct SymmetricStrength{T} <: Strength
7574
end
7675
SymmetricStrength() = SymmetricStrength(0.)
7776

78-
function strength_of_connection(s::SymmetricStrength{T}, A, bsr_flag = false) where {T}
77+
function (s::SymmetricStrength{T})(A, bsr_flag = false) where {T}
7978

8079
θ = s.θ
8180

test/runtests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ ref_split = readdlm("ref_split_test.txt")
1616
# classical strength of connection
1717
A = poisson(5)
1818
A = float.(A)
19-
S, T = strength_of_connection(Classical(0.2), A)
19+
S, T = Classical(0.2)(A)
2020
@test Matrix(S) == [ 1.0 0.5 0.0 0.0 0.0
2121
0.5 1.0 0.5 0.0 0.0
2222
0.0 0.5 1.0 0.5 0.0
2323
0.0 0.0 0.5 1.0 0.5
2424
0.0 0.0 0.0 0.5 1.0 ]
25-
S, T = strength_of_connection(Classical(0.25), graph)
25+
S, T = Classical(0.25)(graph)
2626
diff = S - ref_S
2727
@test maximum(diff) < 1e-10
2828

@@ -32,18 +32,18 @@ end
3232

3333
# Ruge-Stuben splitting
3434
S = poisson(7)
35-
@test split_nodes(RS(), S) == [0, 1, 0, 1, 0, 1, 0]
35+
@test RS()(S) == [0, 1, 0, 1, 0, 1, 0]
3636
srand(0)
3737
S = sprand(10,10,0.1); S = S + S'
38-
@test split_nodes(RS(), S) == [0, 1, 1, 0, 0, 0, 0, 0, 1, 1]
38+
@test RS()(S) == [0, 1, 1, 0, 0, 0, 0, 0, 1, 1]
3939

4040
a = load("thing.jld2")["G"]
41-
S, T = AlgebraicMultigrid.strength_of_connection(Classical(0.25), a)
42-
@test split_nodes(RS(), S) == [0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0,
41+
S, T = Classical(0.25)(a)
42+
@test RS()(S) == [0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0,
4343
0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0,
4444
1, 0]
4545

46-
@test split_nodes(RS(), ref_S) == Int.(vec(ref_split))
46+
@test RS()(ref_S) == Int.(vec(ref_split))
4747

4848
end
4949

0 commit comments

Comments
 (0)