Skip to content

Commit c467855

Browse files
committed
fix v0.6 perf
1 parent 76ac8f1 commit c467855

File tree

8 files changed

+69
-29
lines changed

8 files changed

+69
-29
lines changed

src/aggregation.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
function smoothed_aggregation(A::SparseMatrixCSC{T,V}, ::Type{Val{bs}}=Val{1},
1+
function smoothed_aggregation(A::TA,
2+
::Type{Val{bs}}=Val{1},
23
symmetry = HermitianSymmetry(),
34
strength = SymmetricStrength(),
45
aggregate = StandardAggregation(),
@@ -10,7 +11,7 @@ function smoothed_aggregation(A::SparseMatrixCSC{T,V}, ::Type{Val{bs}}=Val{1},
1011
max_coarse = 10,
1112
diagonal_dominance = false,
1213
keep = false,
13-
coarse_solver = Pinv) where {T,V,bs}
14+
coarse_solver = Pinv) where {T,V,bs,TA<:SparseMatrixCSC{T,V}}
1415

1516
n = size(A, 1)
1617
# B = kron(ones(n, 1), eye(1))
@@ -27,7 +28,11 @@ function smoothed_aggregation(A::SparseMatrixCSC{T,V}, ::Type{Val{bs}}=Val{1},
2728
# agg = [aggregate for _ in 1:max_levels - 1]
2829
# sm = [smooth for _ in 1:max_levels]
2930

30-
levels = Vector{Level{T,V}}()
31+
@static if VERSION < v"0.7-"
32+
levels = Vector{Level{TA, TA, TA}}()
33+
else
34+
levels = Vector{Level{TA, TA, Adjoint{T, TA}}}()
35+
end
3136
bsr_flag = false
3237
w = MultiLevelWorkspace(Val{bs}, eltype(A))
3338

@@ -57,7 +62,11 @@ function extend_hierarchy!(levels, strength, aggregate, smooth,
5762
symmetry, bsr_flag)
5863

5964
# Calculate strength of connection matrix
60-
S = strength(A, bsr_flag)
65+
if symmetry isa HermitianSymmetry
66+
S, _T = strength(A, bsr_flag)
67+
else
68+
S, _T = strength(adjoint(A), bsr_flag)
69+
end
6170

6271
# Aggregation operator
6372
AggOp = aggregate(S)
@@ -80,11 +89,11 @@ function extend_hierarchy!(levels, strength, aggregate, smooth,
8089

8190
A, B, bsr_flag
8291
end
83-
construct_R(::HermitianSymmetry, P) = copy(P')
92+
construct_R(::HermitianSymmetry, P) = P'
8493

8594
function fit_candidates(AggOp, B, tol = 1e-10)
8695

87-
A = copy(AggOp')
96+
A = adjoint(AggOp)
8897
n_fine, n_coarse = size(A)
8998
n_col = n_coarse
9099

src/classical.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,56 @@ struct Solver{S,T,P,PS}
77
max_coarse::Int64
88
end
99

10-
function ruge_stuben(A::SparseMatrixCSC{Ti,Tv}, ::Type{Val{bs}}=Val{1};
10+
function ruge_stuben(_A::Union{TA, Symmetric{Ti, TA}, Hermitian{Ti, TA}},
11+
::Type{Val{bs}}=Val{1};
1112
strength = Classical(0.25),
1213
CF = RS(),
1314
presmoother = GaussSeidel(),
1415
postsmoother = GaussSeidel(),
1516
max_levels = 10,
1617
max_coarse = 10,
17-
coarse_solver = Pinv) where {Ti,Tv,bs}
18+
coarse_solver = Pinv) where {Ti,Tv,bs,TA<:SparseMatrixCSC{Ti,Tv}}
1819

1920
s = Solver(strength, CF, presmoother,
2021
postsmoother, max_levels, max_levels)
2122

22-
levels = Vector{Level{Ti,Tv}}()
23+
if _A isa Symmetric && Ti <: Real || _A isa Hermitian
24+
A = _A.data
25+
At = A
26+
symmetric = true
27+
@static if VERSION < v"0.7-"
28+
levels = Vector{Level{TA, TA}}()
29+
else
30+
levels = Vector{Level{TA, Adjoint{Ti, TA}, TA}}()
31+
end
32+
else
33+
symmetric = false
34+
A = _A
35+
At = adjoint(A)
36+
@static if VERSION < v"0.7-"
37+
levels = Vector{Level{TA, TA, TA}}()
38+
else
39+
levels = Vector{Level{TA, Adjoint{Ti, TA}, TA}}()
40+
end
41+
end
2342
w = MultiLevelWorkspace(Val{bs}, eltype(A))
2443

2544
while length(levels) + 1 < max_levels && size(A, 1) > max_coarse
2645
residual!(w, size(A, 1))
27-
A = extend_heirarchy!(levels, strength, CF, A)
46+
A = extend_heirarchy!(levels, strength, CF, A, symmetric)
2847
coarse_x!(w, size(A, 1))
2948
coarse_b!(w, size(A, 1))
3049
end
3150

3251
MultiLevel(levels, A, coarse_solver(A), presmoother, postsmoother, w)
3352
end
3453

35-
function extend_heirarchy!(levels::Vector{Level{Ti,Tv}}, strength, CF, A::SparseMatrixCSC{Ti,Tv}) where {Ti,Tv}
36-
At = copy(A')
54+
function extend_heirarchy!(levels, strength, CF, A::SparseMatrixCSC{Ti,Tv}, symmetric) where {Ti,Tv}
55+
if symmetric
56+
At = A
57+
else
58+
At = adjoint(A)
59+
end
3760
S, T = strength(At)
3861
splitting = CF(S)
3962
P, R = direct_interpolation(At, T, splitting)
@@ -48,7 +71,7 @@ function direct_interpolation(At, T, splitting)
4871
Pp = rs_direct_interpolation_pass1(T, splitting)
4972
Px, Pj, Pp = rs_direct_interpolation_pass2(At, T, splitting, Pp)
5073
R = SparseMatrixCSC(maximum(Pj), size(At, 1), Pp, Pj, Px)
51-
P = copy(R')
74+
P = R'
5275

5376
P, R
5477
end

src/multilevel.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
struct Level{T,V}
2-
A::SparseMatrixCSC{T,V}
3-
P::SparseMatrixCSC{T,V}
4-
R::SparseMatrixCSC{T,V}
1+
struct Level{TA, TP, TR}
2+
A::TA
3+
P::TP
4+
R::TR
55
end
66

7-
struct MultiLevel{S, Pre, Post, Ti, Tv, TW}
8-
levels::Vector{Level{Ti,Tv}}
9-
final_A::SparseMatrixCSC{Ti,Tv}
7+
struct MultiLevel{S, Pre, Post, TA, TP, TR, TW}
8+
levels::Vector{Level{TA, TP, TR}}
9+
final_A::TA
1010
coarse_solver::S
1111
presmoother::Pre
1212
postsmoother::Post

src/splitting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919

2020
function (::RS)(S)
2121
remove_diag!(S)
22-
RS_CF_splitting(S, copy(S'))
22+
RS_CF_splitting(S, adjoint(S))
2323
end
2424

2525
function RS_CF_splitting(S::SparseMatrixCSC, T::SparseMatrixCSC)

src/strength.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function (c::Classical)(At::SparseMatrixCSC{Tv,Ti}) where {Ti,Tv}
3333

3434
scale_cols_by_largest_entry!(T)
3535

36-
copy(T'), T
36+
adjoint(T), T
3737
end
3838

3939
function find_max_off_diag(A, i)
@@ -81,7 +81,7 @@ function (s::SymmetricStrength{T})(A, bsr_flag = false) where {T}
8181
if bsr_flag && θ == 0
8282
S = SparseMatrixCSC(size(A)...,
8383
A.colptr, A.rowval, ones(eltype(A), size(A.rowval)))
84-
return S
84+
return S, S
8585
else
8686
S = deepcopy(A)
8787
end
@@ -118,5 +118,5 @@ function (s::SymmetricStrength{T})(A, bsr_flag = false) where {T}
118118
S.nzval .= abs.(S.nzval)
119119
scale_cols_by_largest_entry!(S)
120120

121-
S
121+
S, S
122122
end

src/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
function adjoint(A)
2+
@static if VERSION < v"0.7-"
3+
A'
4+
else
5+
copy(A')
6+
end
7+
end
8+
19
function approximate_spectral_radius(A, tol = 0.01,
210
maxiter = 15, restart = 5)
311

test/gmg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import AlgebraicMultigrid: Level, MultiLevel, GaussSeidel
22

3-
function multigrid(A::SparseMatrixCSC{T,V}; max_levels = 10, max_coarse = 10,
4-
presmoother = GaussSeidel(), postsmoother = GaussSeidel()) where {T,V}
3+
function multigrid(A::TA; max_levels = 10, max_coarse = 10,
4+
presmoother = GaussSeidel(), postsmoother = GaussSeidel()) where {T,V,TA<:SparseMatrixCSC{T,V}}
55

6-
levels = Vector{Level{T,V}}()
6+
levels = Vector{Level{TA,TA,TA}}()
77
w = AlgebraicMultigrid.MultiLevelWorkspace(Val{1}, eltype(A))
88

99
while length(levels) + 1 < max_levels && size(A, 1) > max_coarse
@@ -43,7 +43,7 @@ function extend!(levels, A::SparseMatrixCSC{Ti,Tv}) where {Ti,Tv}
4343

4444
P = sparse(I, J, V, size_F, size_C)
4545

46-
R = copy(P')
46+
R = AlgebraicMultigrid.adjoint(P)
4747

4848
push!(levels, Level(A, P, R))
4949

test/sa_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function test_symmetric_soc()
3030
for matrix in cases
3131
for θ in (0.0, 0.1, 0.5, 1., 10.)
3232
ref_matrix = symmetric_soc(matrix, θ)
33-
calc_matrix = SymmetricStrength(θ)(matrix)
33+
calc_matrix, _ = SymmetricStrength(θ)(matrix)
3434

3535
@test sum(abs2, ref_matrix - calc_matrix) < 1e-6
3636
end

0 commit comments

Comments
 (0)