Skip to content

Commit aee7c2d

Browse files
committed
callable Pinv and some inplace
1 parent d193468 commit aee7c2d

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

src/aggregation.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ function smoothed_aggregation(A::SparseMatrixCSC{T,V},
1010
max_coarse = 10,
1111
diagonal_dominance = false,
1212
keep = false,
13-
coarse_solver = Pinv()) where {T,V}
14-
13+
coarse_solver = Pinv) where {T,V}
1514

1615
n = size(A, 1)
1716
# B = kron(ones(n, 1), eye(1))
@@ -42,7 +41,7 @@ function smoothed_aggregation(A::SparseMatrixCSC{T,V},
4241
#=A, B = extend_hierarchy!(levels, strength, aggregate, smooth,
4342
improve_candidates, diagonal_dominance,
4443
keep, A, B, symmetry)=#
45-
MultiLevel(levels, A, presmoother, postsmoother)
44+
MultiLevel(levels, A, coarse_solver(A), presmoother, postsmoother)
4645
end
4746

4847
struct HermitianSymmetry

src/multilevel.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ struct MultiLevel{S, Pre, Post, Ti, Tv}
1313
end
1414

1515
abstract type CoarseSolver end
16-
struct Pinv <: CoarseSolver
16+
struct Pinv{T} <: CoarseSolver
17+
pinvA::Matrix{T}
18+
Pinv(A) = new{eltype(A)}(pinv(Matrix(A)))
1719
end
20+
(p::Pinv)(x, b) = mul!(x, p.pinvA, b)
1821

1922
MultiLevel(l::Vector{Level{Ti,Tv}}, A::SparseMatrixCSC{Ti,Tv}, presmoother, postsmoother) where {Ti,Tv} =
20-
MultiLevel(l, A, Pinv(), presmoother, postsmoother)
23+
MultiLevel(l, A, Pinv(A), presmoother, postsmoother)
2124
Base.length(ml) = length(ml.levels) + 1
2225

2326
function Base.show(io::IO, ml::MultiLevel)
@@ -113,9 +116,9 @@ function solve(ml::MultiLevel, b::AbstractVector{T},
113116
lvl = 1
114117
while length(residuals) <= maxiter && residuals[end] > tol
115118
if length(ml) == 1
116-
x = coarse_solver(ml.coarse_solver, A, b)
119+
ml.coarse_solver(x, b)
117120
else
118-
x = __solve(cycle, ml, x, b, lvl)
121+
__solve!(x, ml, cycle, b, lvl)
119122
end
120123
push!(residuals, T(norm(b - A * x)))
121124
end
@@ -127,7 +130,7 @@ function solve(ml::MultiLevel, b::AbstractVector{T},
127130
return x
128131
end
129132
end
130-
function __solve(v::V, ml, x, b, lvl)
133+
function __solve!(x, ml, v::V, b, lvl)
131134

132135
A = ml.levels[lvl].A
133136
ml.presmoother(A, x, b)
@@ -137,9 +140,9 @@ function __solve(v::V, ml, x, b, lvl)
137140
coarse_x = zeros(eltype(coarse_b), size(coarse_b))
138141

139142
if lvl == length(ml.levels)
140-
coarse_x = coarse_solver(ml.coarse_solver, ml.final_A, coarse_b)
143+
ml.coarse_solver(coarse_x, coarse_b)
141144
else
142-
coarse_x = __solve(v, ml, coarse_x, coarse_b, lvl + 1)
145+
coarse_x = __solve!(coarse_x, ml, v, coarse_b, lvl + 1)
143146
end
144147

145148
x .+= ml.levels[lvl].P * coarse_x
@@ -148,5 +151,3 @@ function __solve(v::V, ml, x, b, lvl)
148151

149152
x
150153
end
151-
152-
coarse_solver(::Pinv, A, b) = pinv(Matrix(A)) * b

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Compat, Compat.Test, Compat.LinearAlgebra
22
using Compat.SparseArrays, Compat.DelimitedFiles, Compat.Random
33
using IterativeSolvers, FileIO, AlgebraicMultigrid
4-
import AlgebraicMultigrid: V, coarse_solver, Pinv, Classical
4+
import AlgebraicMultigrid: Pinv, Classical
55

66
include("sa_tests.jl")
77

@@ -68,7 +68,7 @@ end
6868
@testset "Coarse Solver" begin
6969
A = float.(poisson(10))
7070
b = A * ones(10)
71-
@test sum(abs2, coarse_solver(Pinv(), A, b) - ones(10)) < 1e-6
71+
@test sum(abs2, Pinv(A)(similar(b), b) - ones(10)) < 1e-6
7272
end
7373

7474
@testset "Multilevel" begin

0 commit comments

Comments
 (0)