Skip to content

Commit 730b501

Browse files
committed
update
1 parent 1553c34 commit 730b501

File tree

4 files changed

+48
-37
lines changed

4 files changed

+48
-37
lines changed

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using RecursiveFactorization
55
using Base: cache_dependencies, Bool
66
import Base: eltype, adjoint, inv
77
using LinearAlgebra
8+
using IterativeSolvers:Identity
89
using SparseArrays
910
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
1011
using Setfield

src/common.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,20 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
6060
reltol=eps(eltype(prob.A)),
6161
maxiters=length(prob.b),
6262
verbose=false,
63-
scale=one(eltype(prob.A)),
63+
Pl = nothing,
64+
Pr = nothing,
6465
kwargs...,
6566
)
6667
@unpack A, b, u0, p = prob
6768

68-
u0 = (u0 === nothing) ? zero(b) : u0
69+
u0 = (u0 !== nothing) ? u0 : zero(b)
70+
Pl = (Pl !== nothing) ? Pl : Identity()
71+
Pr = (Pr !== nothing) ? Pr : Identity()
6972

7073
cacheval = init_cacheval(alg, A, b, u0)
7174
isfresh = cacheval === nothing
7275
Tc = isfresh ? Any : typeof(cacheval)
7376

74-
Pl = scaling_preconditioner(scale, true)
75-
Pr = scaling_preconditioner(scale, false)
76-
7777
A = alias_A ? A : deepcopy(A)
7878
b = alias_b ? b : deepcopy(b)
7979

src/wrappers.jl

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
## Preconditioners
33

4-
scaling_preconditioner(s, isleft) = isleft ? I * s : I * (1/s)
4+
scaling_preconditioner(s) = I * s , I * (1/s)
55

66
struct ComposePreconditioner{Ti,To}
77
inner::Ti
@@ -41,6 +41,23 @@ function LinearAlgebra.mul!(y, A::InvComposePreconditioner, x)
4141
ldiv!(y, P, x)
4242
end
4343

44+
function get_preconditioner(Pi, Po)
45+
46+
ifPi = Pi !== Identity()
47+
ifPo = Po !== Identity()
48+
49+
P =
50+
if ifPi & ifPo
51+
ComposePreconditioner(Pi, Po)
52+
elseif ifPi | ifPo
53+
ifPi ? Pi : Po
54+
else
55+
Identity()
56+
end
57+
58+
return P
59+
end
60+
4461
## Krylov.jl
4562

4663
struct KrylovJL{F,Tl,Tr,I,A,K} <: AbstractKrylovSubspaceMethod
@@ -53,10 +70,14 @@ struct KrylovJL{F,Tl,Tr,I,A,K} <: AbstractKrylovSubspaceMethod
5370
kwargs::K
5471
end
5572

56-
function KrylovJL(args...; KrylovAlg = Krylov.gmres!, Pl=I, Pr=I,
73+
function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
74+
Pl=nothing, Pr=nothing,
5775
gmres_restart=0, window=0,
5876
kwargs...)
5977

78+
Pl = (Pl === nothing) ? Identity() : Pl
79+
Pr = (Pr === nothing) ? Identity() : Pr
80+
6081
return KrylovJL(KrylovAlg, Pl, Pr, gmres_restart, window,
6182
args, kwargs)
6283
end
@@ -144,16 +165,11 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
144165
cache = set_cacheval(cache, solver)
145166
end
146167

147-
M = I # left precond
148-
N = I # right precond
168+
M = get_preconditioner(alg.Pl, cache.Pl)
169+
N = get_preconditioner(alg.Pr, cache.Pr)
149170

150-
if (cache.Pl != I) | (alg.Pl != I)
151-
M = InvComposePreconditioner(alg.Pl, cache.Pl)
152-
end
153-
154-
if (cache.Pr != I) | (alg.Pr != I)
155-
N = InvComposePreconditioner(alg.Pr, cache.Pr)
156-
end
171+
M = (M === Identity()) ? I : inv(M)
172+
N = (N === Identity()) ? I : inv(N)
157173

158174
atol = cache.abstol
159175
rtol = cache.reltol
@@ -165,7 +181,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
165181
alg.kwargs...)
166182

167183
if cache.cacheval isa Krylov.CgSolver
168-
N != I &&
184+
N !== I &&
169185
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
170186
Krylov.solve!(args...; M=M,
171187
kwargs...)
@@ -176,7 +192,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
176192
Krylov.solve!(args...; M=M, N=N,
177193
kwargs...)
178194
elseif cache.cacheval isa Krylov.MinresSolver
179-
N != I &&
195+
N !== I &&
180196
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
181197
Krylov.solve!(args...; M=M,
182198
kwargs...)
@@ -200,9 +216,12 @@ end
200216

201217
function IterativeSolversJL(args...;
202218
generate_iterator = IterativeSolvers.gmres_iterable!,
203-
Pl=IterativeSolvers.Identity(),
204-
Pr=IterativeSolvers.Identity(),
219+
Pl=nothing, Pr=nothing,
205220
gmres_restart=0, kwargs...)
221+
222+
Pl = (Pl === nothing) ? Identity() : Pl
223+
Pr = (Pr === nothing) ? Identity() : Pr
224+
206225
return IterativeSolversJL(generate_iterator, Pl, Pr, gmres_restart,
207226
args, kwargs)
208227
end
@@ -227,16 +246,8 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
227246
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
228247
@unpack A, b, u = cache
229248

230-
Pl = IterativeSolvers.Identity()
231-
Pr = IterativeSolvers.Identity()
232-
233-
if (cache.Pl != I) | (alg.Pl != IterativeSolvers.Identity())
234-
Pl = ComposePreconditioner(alg.Pl, cache.Pl)
235-
end
236-
237-
if (cache.Pr != I) | (alg.Pr != IterativeSolvers.Identity())
238-
Pr = ComposePreconditioner(alg.Pr, cache.Pr)
239-
end
249+
Pl = get_preconditioner(alg.Pl, cache.Pl)
250+
Pr = get_preconditioner(alg.Pr, cache.Pr)
240251

241252
abstol = cache.abstol
242253
reltol = cache.reltol
@@ -249,15 +260,15 @@ function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
249260
alg.kwargs...)
250261

251262
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
252-
Pr != IterativeSolvers.Identity() &&
263+
Pr !== Identity() &&
253264
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
254265
alg.generate_iterator(u, A, b, Pl;
255266
kwargs...)
256267
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
257268
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr, restart=restart,
258269
kwargs...)
259270
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
260-
Pr != IterativeSolvers.Identity() &&
271+
Pr !== Identity() &&
261272
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
262273
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl,
263274
abstol=abstol, reltol=reltol,

test/runtests.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ end
133133
x = rand(n,n)
134134
y = rand(n,n)
135135

136-
Pl = LinearSolve.scaling_preconditioner(s, true)
137-
Pr = LinearSolve.scaling_preconditioner(s, false)
136+
Pl, Pr = LinearSolve.scaling_preconditioner(s)
138137

139138
mul!(y, Pl, x); @test y s * x
140139
mul!(y, Pr, x); @test y s \ x
@@ -154,13 +153,13 @@ end
154153
x = rand(n,n)
155154
y = rand(n,n)
156155

157-
P1 = LinearSolve.scaling_preconditioner(s1, true)
158-
P2 = LinearSolve.scaling_preconditioner(s2, true)
156+
P1, _ = LinearSolve.scaling_preconditioner(s1)
157+
P2, _ = LinearSolve.scaling_preconditioner(s2)
159158

160159
P = LinearSolve.ComposePreconditioner(P1,P2)
161160
Pi = LinearSolve.InvComposePreconditioner(P)
162161

163-
@test Pi == LinearSolve.InvComposePreconditioner(P1,P2)
162+
@test Pi == LinearSolve.InvComposePreconditioner(P1, P2)
164163
@test Pi == inv(P)
165164
@test P == inv(Pi)
166165
@test Pi' == inv(P')

0 commit comments

Comments
 (0)