Skip to content

Commit dd7b7fb

Browse files
authored
Merge pull request #87 from JuliaLinearAlgebra/RA/commonsolve
Add commonsolve interface
2 parents b40945a + 1a43091 commit dd7b7fb

File tree

9 files changed

+69
-26
lines changed

9 files changed

+69
-26
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ uuid = "2169fc97-5a83-5252-b627-83903c6c433c"
33
version = "0.4.2"
44

55
[deps]
6+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
67
CompatHelper = "aa819f21-2bde-4658-8897-bab36330d9b7"
78
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
89
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
13+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1214
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1315
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1416

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@ This package lets you solve sparse linear systems using Algebraic Multigrid (AMG
99

1010
## Usage
1111

12+
### Using the CommonSolve interface
13+
14+
This is highest level API. It internally creates the multilevel object
15+
and calls the multigrid cycling `_solve`.
16+
17+
```julia
18+
A = poisson(100);
19+
b = rand(100);
20+
solve(A, b, RugeStubenAMG(), maxiter = 1, abstol = 1e-6)
21+
```
22+
23+
### Multigrid cycling
24+
1225
```julia
1326
using AlgebraicMultigrid
1427

@@ -32,7 +45,7 @@ ml = ruge_stuben(A) # Construct a Ruge-Stuben solver
3245
# 8 7 19 [ 0.32%]
3346

3447

35-
solve(ml, A * ones(1000)) # should return ones(1000)
48+
AlgebraicMultigrid._solve(ml, A * ones(1000)) # should return ones(1000)
3649
```
3750

3851
### As a Preconditioner

src/AlgebraicMultigrid.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module AlgebraicMultigrid
22

3+
using Reexport
34
import IterativeSolvers: gauss_seidel!
45
using LinearAlgebra
56
using SparseArrays, Printf
67
using Base.Threads
8+
@reexport import CommonSolve: solve, solve!, init
9+
using Reexport
710

811
using LinearAlgebra: rmul!
912

@@ -29,7 +32,7 @@ export GaussSeidel, SymmetricSweep, ForwardSweep, BackwardSweep,
2932
JacobiProlongation
3033

3134
include("multilevel.jl")
32-
export solve
35+
export RugeStubenAMG, SmoothedAggregationAMG
3336

3437
include("classical.jl")
3538
export ruge_stuben

src/aggregation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function smoothed_aggregation(A::TA,
1111
max_coarse = 10,
1212
diagonal_dominance = false,
1313
keep = false,
14-
coarse_solver = Pinv) where {T,V,bs,TA<:SparseMatrixCSC{T,V}}
14+
coarse_solver = Pinv, kwargs...) where {T,V,bs,TA<:SparseMatrixCSC{T,V}}
1515

1616
n = size(A, 1)
1717
# B = kron(ones(n, 1), eye(1))

src/classical.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function ruge_stuben(_A::Union{TA, Symmetric{Ti, TA}, Hermitian{Ti, TA}},
1515
postsmoother = GaussSeidel(),
1616
max_levels = 10,
1717
max_coarse = 10,
18-
coarse_solver = Pinv) where {Ti,Tv,bs,TA<:SparseMatrixCSC{Ti,Tv}}
18+
coarse_solver = Pinv, kwargs...) where {Ti,Tv,bs,TA<:SparseMatrixCSC{Ti,Tv}}
1919

2020
s = Solver(strength, CF, presmoother,
2121
postsmoother, max_levels, max_levels)

src/multilevel.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ struct F <: Cycle
126126
end
127127

128128
"""
129-
solve(ml::MultiLevel, b::AbstractArray, cycle, kwargs...)
129+
_solve(ml::MultiLevel, b::AbstractArray, cycle, kwargs...)
130130
131131
Execute multigrid cycling.
132132
@@ -145,20 +145,20 @@ Keyword Arguments
145145
* log::Bool - return vector of residuals along with solution
146146
147147
"""
148-
function solve(ml::MultiLevel, b::AbstractArray, args...; kwargs...)
148+
function _solve(ml::MultiLevel, b::AbstractArray, args...; kwargs...)
149149
n = length(ml) == 1 ? size(ml.final_A, 1) : size(ml.levels[1].A, 1)
150150
V = promote_type(eltype(ml.workspace), eltype(b))
151151
x = zeros(V, size(b))
152-
return solve!(x, ml, b, args...; kwargs...)
152+
return _solve!(x, ml, b, args...; kwargs...)
153153
end
154-
function solve!(x, ml::MultiLevel, b::AbstractArray{T},
154+
function _solve!(x, ml::MultiLevel, b::AbstractArray{T},
155155
cycle::Cycle = V();
156156
maxiter::Int = 100,
157157
abstol::Real = zero(real(eltype(b))),
158158
reltol::Real = sqrt(eps(real(eltype(b)))),
159159
verbose::Bool = false,
160160
log::Bool = false,
161-
calculate_residual = true) where {T}
161+
calculate_residual = true, kwargs...) where {T}
162162

163163
A = length(ml) == 1 ? ml.final_A : ml.levels[1].A
164164
V = promote_type(eltype(A), eltype(b))
@@ -233,3 +233,28 @@ function __solve!(x, ml, cycle::Cycle, b, lvl)
233233

234234
x
235235
end
236+
237+
### CommonSolve.jl spec
238+
struct AMGSolver{T}
239+
ml::MultiLevel
240+
b::Vector{T}
241+
end
242+
243+
abstract type AMGAlg end
244+
245+
struct RugeStubenAMG <: AMGAlg end
246+
struct SmoothedAggregationAMG <: AMGAlg end
247+
248+
function solve(A::AbstractMatrix, b::Vector, s::AMGAlg, args...; kwargs...)
249+
solt = init(s, A, b, args...; kwargs...)
250+
solve!(solt, args...; kwargs...)
251+
end
252+
function init(::RugeStubenAMG, A, b, args...; kwargs...)
253+
AMGSolver(ruge_stuben(A; kwargs...), b)
254+
end
255+
function init(::SmoothedAggregationAMG, A, b; kwargs...)
256+
AMGSolver(smoothed_aggregation(A; kwargs...), b)
257+
end
258+
function solve!(solt::AMGSolver, args...; kwargs...)
259+
_solve(solt.ml, solt.b, args...; kwargs...)
260+
end

src/preconditioner.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function ldiv!(x, p::Preconditioner, b)
1515
else
1616
x .= b
1717
end
18-
solve!(x, p.ml, b, p.cycle, maxiter = 1, calculate_residual = false)
18+
_solve!(x, p.ml, b, p.cycle, maxiter = 1, calculate_residual = false)
1919
end
2020
mul!(b, p::Preconditioner, x) = mul!(b, p.ml.levels[1].A, x)
2121

test/cycle_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function test_cycles()
1414
ml = method(A)
1515

1616
for cycle in [AlgebraicMultigrid.V(),AlgebraicMultigrid.W(),AlgebraicMultigrid.F()]
17-
x,convhist = solve(ml, b, cycle; reltol = reltol, log = true)
17+
x,convhist = AlgebraicMultigrid._solve(ml, b, cycle; reltol = reltol, log = true)
1818

1919
@debug "number of iterations for $cycle using $method: $(length(convhist))"
2020
@test norm(b - A*x) < reltol * norm(b)

test/runtests.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,29 +120,26 @@ fsmoother = GaussSeidel(ForwardSweep())
120120
A = poisson(1000)
121121
A = float.(A)
122122
ml = ruge_stuben(A)
123-
x = solve(ml, A * ones(1000))
123+
x = AlgebraicMultigrid._solve(ml, A * ones(1000))
124124
@test sum(abs2, x - ones(1000)) < 1e-8
125125

126126
ml = ruge_stuben(A, presmoother = fsmoother,
127127
postsmoother = fsmoother)
128-
x = solve(ml, A * ones(1000))
128+
x = AlgebraicMultigrid._solve(ml, A * ones(1000))
129129
@test sum(abs2, x - ones(1000)) < 1e-8
130130

131131

132132
A = include("randlap.jl")
133133

134134
ml = ruge_stuben(A, presmoother = fsmoother,
135135
postsmoother = fsmoother)
136-
x = solve(ml, A * ones(100))
136+
x = AlgebraicMultigrid._solve(ml, A * ones(100))
137137
@test sum(abs2, x - zeros(100)) < 1e-8
138138

139139
ml = ruge_stuben(A)
140-
x = solve(ml, A * ones(100))
140+
x = AlgebraicMultigrid._solve(ml, A * ones(100))
141141
@test sum(abs2, x - zeros(100)) < 1e-6
142142

143-
144-
145-
146143
end
147144

148145
@testset "Preconditioning" begin
@@ -155,7 +152,7 @@ p = aspreconditioner(ml)
155152
b = zeros(n)
156153
b[1] = 1
157154
b[2] = -1
158-
x = solve(p.ml, A * ones(n), maxiter = 1, abstol = 1e-12)
155+
x = AlgebraicMultigrid._solve(ml, A * ones(n), maxiter = 1, abstol = 1e-12)
159156
diff = x - [ 1.88664780e-16, 2.34982727e-16, 2.33917697e-16,
160157
8.77869044e-17, 7.16783490e-17, 1.43415460e-16,
161158
3.69199021e-17, 9.70950385e-17, 4.77034895e-17,
@@ -173,7 +170,9 @@ diff = x - [ 1.88664780e-16, 2.34982727e-16, 2.33917697e-16,
173170
-6.76965535e-16, -7.00643227e-16, -6.23581397e-16,
174171
-7.03016682e-16]
175172
@test sum(abs2, diff) < 1e-8
176-
x = solve(p.ml, b, maxiter = 1, abstol = 1e-12)
173+
x = solve(A, b, RugeStubenAMG(); presmoother = smoother,
174+
postsmoother = smoother,
175+
maxiter = 1, abstol = 1e-12)
177176
diff = x - [ 0.76347046, -0.5498286 , -0.2705487 , -0.15047352, -0.10248021,
178177
0.60292674, -0.11497073, -0.08460548, -0.06931461, 0.38230708,
179178
-0.055664 , -0.04854558, -0.04577031, 0.09964325, 0.01825624,
@@ -214,7 +213,7 @@ diff = x - [0.823762, -0.537478, -0.306212, -0.19359, -0.147621, 0.685002,
214213
0.0511691, 0.0502043, 0.0498349, 0.0498134]
215214
@test sum(abs2, diff) < 1e-8
216215

217-
x = solve(ml, b, maxiter = 1, reltol = 1e-12)
216+
x = AlgebraicMultigrid._solve(ml, b, maxiter = 1, reltol = 1e-12)
218217
diff = x - [0.775725, -0.571202, -0.290989, -0.157001, -0.106981, 0.622652,
219218
-0.122318, -0.0891874, -0.0709834, 0.392621, -0.055544, -0.0507485,
220219
-0.0466376, 0.107175, 0.0267468, -0.0200843, -0.0282827, -0.0299929,
@@ -240,7 +239,7 @@ for (T,V) in ((Float64, Float64), (Float32,Float32),
240239
ml = smoothed_aggregation(a)
241240
b = V.(b)
242241
c = cg(a, b, maxiter = 10)
243-
@test eltype(solve(ml, b)) == eltype(c)
242+
@test eltype(AlgebraicMultigrid._solve(ml, b)) == eltype(c)
244243
end
245244

246245
end
@@ -308,15 +307,16 @@ for sz in [10, 5, 2]
308307
end
309308

310309
# Issue #46
311-
for f in (smoothed_aggregation, ruge_stuben)
310+
for f in ((smoothed_aggregation, SmoothedAggregationAMG),
311+
(ruge_stuben, RugeStubenAMG))
312312

313313
a = load("bug.jld2")["G"]
314-
ml = f(a)
314+
ml = f[1](a)
315315
p = aspreconditioner(ml)
316316
b = zeros(size(a,1))
317317
b[1] = 1
318318
b[2] = -1
319-
@test sum(abs2, a * solve(ml, b) - b) < 1e-10
319+
@test sum(abs2, a * solve(a, b, f[2]()) - b) < 1e-10
320320
@test sum(abs2, a * cg(a, b, Pl = p, maxiter = 1000) - b) < 1e-10
321321

322322
end
@@ -327,4 +327,4 @@ end
327327
X = poisson(27_000)+24.0*I
328328
ml = ruge_stuben(X)
329329
b = rand(27_000)
330-
@test solve(ml, b, reltol = 1e-10) X \ b rtol = 1e-10
330+
@test AlgebraicMultigrid._solve(ml, b, reltol = 1e-10) X \ b rtol = 1e-10

0 commit comments

Comments
 (0)