Skip to content

Commit 3af8a74

Browse files
LinearSolve.jl integration and remove some dead code.
1 parent e567b00 commit 3af8a74

File tree

4 files changed

+45
-7
lines changed

4 files changed

+45
-7
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.5.1"
55
[deps]
66
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
89
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
910
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1011
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

src/AlgebraicMultigrid.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,13 @@ module AlgebraicMultigrid
22

33
using Reexport
44
using LinearAlgebra
5+
using LinearSolve
56
using SparseArrays, Printf
6-
using Base.Threads
77
@reexport import CommonSolve: solve, solve!, init
88
using Reexport
99

1010
using LinearAlgebra: rmul!
1111

12-
# const mul! = A_mul_B!
13-
14-
const MT = false
15-
const AMG = AlgebraicMultigrid
16-
1712
include("utils.jl")
1813
export approximate_spectral_radius
1914

src/multilevel.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ function coarse_b!(m::MultiLevelWorkspace{TX, bs}, n) where {TX, bs}
5252
end
5353

5454
abstract type CoarseSolver end
55+
56+
"""
57+
Pinv{T} <: CoarseSolver
58+
59+
Moore-Penrose pseudo inverse coarse solver. Calls `pinv`
60+
"""
5561
struct Pinv{T} <: CoarseSolver
5662
pinvA::Matrix{T}
5763
Pinv{T}(A) where T = new{T}(pinv(Matrix(A)))
@@ -61,6 +67,43 @@ Base.show(io::IO, p::Pinv) = print(io, "Pinv")
6167

6268
(p::Pinv)(x, b) = mul!(x, p.pinvA, b)
6369

70+
# This one is used internally.
71+
"""
72+
LinearSolveWrapperInternal <: CoarseSolver
73+
74+
Helper to allow the usage of LinearSolve.jl solvers for the coarse-level solve. Constructed via `LinearSolveWrapper`.
75+
"""
76+
struct LinearSolveWrapperInternal{LC <: LinearSolve.LinearCache} <: CoarseSolver
77+
linsolve::LC
78+
function LinearSolveWrapperInternal(A, alg::LinearSolve.SciMLLinearSolveAlgorithm)
79+
rhs_tmp = zeros(eltype(A), size(A,1))
80+
u_tmp = zeros(eltype(A), size(A,2))
81+
linprob = LinearProblem(A, rhs_tmp; u0 = u_tmp, alias_A = false, alias_b = false)
82+
linsolve = init(linprob, alg)
83+
new{typeof(linsolve)}(linsolve)
84+
end
85+
end
86+
87+
function (p::LinearSolveWrapperInternal{LC})(x, b) where {LC <: LinearSolve.LinearCache}
88+
for i 1:size(b, 2)
89+
# Update right hand side
90+
p.linsolve.b = b[:, i]
91+
# Solve for x and update
92+
x[:, i] = solve!(p.linsolve).u
93+
end
94+
end
95+
96+
# This one simplifies passing of LinearSolve.jl algorithms into AlgebraicMultigrid.jl as coarse solvers.
97+
"""
98+
LinearSolveWrapper <: CoarseSolver
99+
100+
Helper to allow the usage of LinearSolve.jl solvers for the coarse-level solve.
101+
"""
102+
struct LinearSolveWrapper <: CoarseSolver
103+
alg::LinearSolve.SciMLLinearSolveAlgorithm
104+
end
105+
(p::LinearSolveWrapper)(A::AbstractMatrix) = LinearSolveWrapperInternal(A, p.alg)
106+
64107
Base.length(ml::MultiLevel) = length(ml.levels) + 1
65108

66109
function Base.show(io::IO, ml::MultiLevel)

src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ end
9999
find_breakdown(::Type{Float64}) = eps(Float64) * 10^6
100100
find_breakdown(::Type{Float32}) = eps(Float64) * 10^3
101101

102-
using Base.Threads
103102
#=function mul!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
104103
A.n == size(B, 1) || throw(DimensionMismatch())
105104
A.m == size(C, 1) || throw(DimensionMismatch())

0 commit comments

Comments
 (0)