Skip to content

Commit bda20fe

Browse files
Merge pull request #33 from vpuri3/vp-precond
default preconditioners
2 parents e6e42b0 + 730b501 commit bda20fe

File tree

5 files changed

+137
-34
lines changed

5 files changed

+137
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1919
[compat]
2020
ArrayInterface = "3"
2121
IterativeSolvers = "0.9.2"
22-
Krylov = "0.7"
22+
Krylov = "0.7.9"
2323
KrylovKit = "0.5"
2424
RecursiveFactorization = "0.2"
2525
Reexport = "1"

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ module LinearSolve
33
using ArrayInterface
44
using RecursiveFactorization
55
using Base: cache_dependencies, Bool
6+
import Base: eltype, adjoint, inv
67
using LinearAlgebra
8+
using IterativeSolvers:Identity
79
using SparseArrays
810
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
911
using Setfield

src/common.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ function set_cacheval(cache::LinearCache, alg_cache)
4444
return cache
4545
end
4646

47+
function set_prec(cache, Pl, Pr)
48+
@set! cache.Pl = Pl
49+
@set! cache.Pr = Pr
50+
return cache
51+
end
52+
4753
init_cacheval(alg::Union{SciMLLinearSolveAlgorithm,Nothing}, A, b, u) = nothing
4854

4955
SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)
@@ -54,19 +60,20 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
5460
reltol=eps(eltype(prob.A)),
5561
maxiters=length(prob.b),
5662
verbose=false,
63+
Pl = nothing,
64+
Pr = nothing,
5765
kwargs...,
5866
)
5967
@unpack A, b, u0, p = prob
6068

61-
u0 = (u0 === nothing) ? zero(b) : u0
69+
u0 = (u0 !== nothing) ? u0 : zero(b)
70+
Pl = (Pl !== nothing) ? Pl : Identity()
71+
Pr = (Pr !== nothing) ? Pr : Identity()
6272

6373
cacheval = init_cacheval(alg, A, b, u0)
6474
isfresh = cacheval === nothing
6575
Tc = isfresh ? Any : typeof(cacheval)
6676

67-
Pl = LinearAlgebra.I
68-
Pr = LinearAlgebra.I
69-
7077
A = alias_A ? A : deepcopy(A)
7178
b = alias_b ? b : deepcopy(b)
7279

src/wrappers.jl

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,61 @@
11

2-
#TODO: composed preconditioners, preconditioner setter for cache,
3-
# detailed tests for wrappers
4-
52
## Preconditioners
63

7-
struct ScaleVector{T}
8-
s::T
9-
isleft::Bool
4+
scaling_preconditioner(s) = I * s , I * (1/s)
5+
6+
struct ComposePreconditioner{Ti,To}
7+
inner::Ti
8+
outer::To
109
end
1110

12-
function LinearAlgebra.ldiv!(v::ScaleVector, x)
11+
Base.eltype(A::ComposePreconditioner) = promote_type(eltype(A.inner), eltype(A.outer))
12+
Base.adjoint(A::ComposePreconditioner) = ComposePreconditioner(A.outer', A.inner')
13+
Base.inv(A::ComposePreconditioner) = InvComposePreconditioner(A)
14+
15+
function LinearAlgebra.ldiv!(A::ComposePreconditioner, x)
16+
@unpack inner, outer = A
17+
18+
ldiv!(inner, x)
19+
ldiv!(outer, x)
1320
end
1421

15-
function LinearAlgebra.ldiv!(y, v::ScaleVector, x)
22+
function LinearAlgebra.ldiv!(y, A::ComposePreconditioner, x)
23+
@unpack inner, outer = A
24+
25+
ldiv!(y, inner, x)
26+
ldiv!(outer, y)
1627
end
1728

18-
struct ComposePreconditioner{Ti,To}
19-
inner::Ti
20-
outer::To
21-
isleft::Bool
29+
struct InvComposePreconditioner{Tp <: ComposePreconditioner}
30+
P::Tp
2231
end
2332

24-
function LinearAlgebra.ldiv!(v::ComposePreconditioner, x)
25-
@unpack inner, outer, isleft = v
33+
InvComposePreconditioner(inner, outer) = InvComposePreconditioner(ComposePreconditioner(inner, outer))
34+
35+
Base.eltype(A::InvComposePreconditioner) = Base.eltype(A.P)
36+
Base.adjoint(A::InvComposePreconditioner) = InvComposePreconditioner(A.P')
37+
Base.inv(A::InvComposePreconditioner) = deepcopy(A.P)
38+
39+
function LinearAlgebra.mul!(y, A::InvComposePreconditioner, x)
40+
@unpack P = A
41+
ldiv!(y, P, x)
2642
end
2743

28-
function LinearAlgebra.ldiv!(y, v::ComposePreconditioner, x)
29-
@unpack inner, outer, isleft = v
44+
function get_preconditioner(Pi, Po)
45+
46+
ifPi = Pi !== Identity()
47+
ifPo = Po !== Identity()
48+
49+
P =
50+
if ifPi & ifPo
51+
ComposePreconditioner(Pi, Po)
52+
elseif ifPi | ifPo
53+
ifPi ? Pi : Po
54+
else
55+
Identity()
56+
end
57+
58+
return P
3059
end
3160

3261
## Krylov.jl
@@ -41,10 +70,14 @@ struct KrylovJL{F,Tl,Tr,I,A,K} <: AbstractKrylovSubspaceMethod
4170
kwargs::K
4271
end
4372

44-
function KrylovJL(args...; KrylovAlg = Krylov.gmres!, Pl=I, Pr=I,
73+
function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
74+
Pl=nothing, Pr=nothing,
4575
gmres_restart=0, window=0,
4676
kwargs...)
4777

78+
Pl = (Pl === nothing) ? Identity() : Pl
79+
Pr = (Pr === nothing) ? Identity() : Pr
80+
4881
return KrylovJL(KrylovAlg, Pl, Pr, gmres_restart, window,
4982
args, kwargs)
5083
end
@@ -132,6 +165,12 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
132165
cache = set_cacheval(cache, solver)
133166
end
134167

168+
M = get_preconditioner(alg.Pl, cache.Pl)
169+
N = get_preconditioner(alg.Pr, cache.Pr)
170+
171+
M = (M === Identity()) ? I : inv(M)
172+
N = (N === Identity()) ? I : inv(N)
173+
135174
atol = cache.abstol
136175
rtol = cache.reltol
137176
itmax = cache.maxiters
@@ -142,20 +181,20 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
142181
alg.kwargs...)
143182

144183
if cache.cacheval isa Krylov.CgSolver
145-
alg.Pr != LinearAlgebra.I &&
184+
N !== I &&
146185
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
147-
Krylov.solve!(args...; M=alg.Pl,
186+
Krylov.solve!(args...; M=M,
148187
kwargs...)
149188
elseif cache.cacheval isa Krylov.GmresSolver
150-
Krylov.solve!(args...; M=alg.Pl, N=alg.Pr,
189+
Krylov.solve!(args...; M=M, N=N,
151190
kwargs...)
152191
elseif cache.cacheval isa Krylov.BicgstabSolver
153-
Krylov.solve!(args...; M=alg.Pl, N=alg.Pr,
192+
Krylov.solve!(args...; M=M, N=N,
154193
kwargs...)
155194
elseif cache.cacheval isa Krylov.MinresSolver
156-
alg.Pr != LinearAlgebra.I &&
195+
N !== I &&
157196
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
158-
Krylov.solve!(args...; M=alg.Pl,
197+
Krylov.solve!(args...; M=M,
159198
kwargs...)
160199
else
161200
Krylov.solve!(args...; kwargs...)
@@ -177,9 +216,12 @@ end
177216

178217
function IterativeSolversJL(args...;
179218
generate_iterator = IterativeSolvers.gmres_iterable!,
180-
Pl=IterativeSolvers.Identity(),
181-
Pr=IterativeSolvers.Identity(),
219+
Pl=nothing, Pr=nothing,
182220
gmres_restart=0, kwargs...)
221+
222+
Pl = (Pl === nothing) ? Identity() : Pl
223+
Pr = (Pr === nothing) ? Identity() : Pr
224+
183225
return IterativeSolversJL(generate_iterator, Pl, Pr, gmres_restart,
184226
args, kwargs)
185227
end
@@ -204,8 +246,8 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
204246
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
205247
@unpack A, b, u = cache
206248

207-
Pl = (alg.Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pl
208-
Pr = (alg.Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pr
249+
Pl = get_preconditioner(alg.Pl, cache.Pl)
250+
Pr = get_preconditioner(alg.Pr, cache.Pr)
209251

210252
abstol = cache.abstol
211253
reltol = cache.reltol
@@ -218,15 +260,15 @@ function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
218260
alg.kwargs...)
219261

220262
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
221-
Pr != IterativeSolvers.Identity() &&
263+
Pr !== Identity() &&
222264
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
223265
alg.generate_iterator(u, A, b, Pl;
224266
kwargs...)
225267
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
226268
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr, restart=restart,
227269
kwargs...)
228270
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
229-
Pr != IterativeSolvers.Identity() &&
271+
Pr !== Identity() &&
230272
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
231273
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl,
232274
abstol=abstol, reltol=reltol,

test/runtests.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ function test_interface(alg, prob1, prob2)
3434
return
3535
end
3636

37+
@testset "LinearSolve" begin
38+
3739
@testset "Default Linear Solver" begin
3840
test_interface(nothing, prob1, prob2)
3941

@@ -123,3 +125,53 @@ end
123125
end
124126
end
125127
end
128+
129+
@testset "Preconditioners" begin
130+
@testset "scaling_preconditioner" begin
131+
s = rand()
132+
133+
x = rand(n,n)
134+
y = rand(n,n)
135+
136+
Pl, Pr = LinearSolve.scaling_preconditioner(s)
137+
138+
mul!(y, Pl, x); @test y s * x
139+
mul!(y, Pr, x); @test y s \ x
140+
141+
y .= x; ldiv!(Pl, x); @test x s \ y
142+
y .= x; ldiv!(Pr, x); @test x s * y
143+
144+
ldiv!(y, Pl, x); @test y s \ x
145+
ldiv!(y, Pr, x); @test y s * x
146+
147+
end
148+
149+
@testset "ComposePreconditioenr" begin
150+
s1 = rand()
151+
s2 = rand()
152+
153+
x = rand(n,n)
154+
y = rand(n,n)
155+
156+
P1, _ = LinearSolve.scaling_preconditioner(s1)
157+
P2, _ = LinearSolve.scaling_preconditioner(s2)
158+
159+
P = LinearSolve.ComposePreconditioner(P1,P2)
160+
Pi = LinearSolve.InvComposePreconditioner(P)
161+
162+
@test Pi == LinearSolve.InvComposePreconditioner(P1, P2)
163+
@test Pi == inv(P)
164+
@test P == inv(Pi)
165+
@test Pi' == inv(P')
166+
167+
# ComposePreconditioner
168+
ldiv!(y, P, x); @test y ldiv!(P2, ldiv!(P1, x))
169+
y .= x; ldiv!(P, x); @test x ldiv!(P2, ldiv!(P1, y))
170+
171+
# InvComposePreconditioner
172+
mul!(y, Pi, x); @test y ldiv!(P2, ldiv!(P1, x))
173+
174+
end
175+
end
176+
177+
end # testset

0 commit comments

Comments
 (0)