Skip to content

Commit 4c9b487

Browse files
committed
iterative solvers wrapper almost done. add tests
1 parent 308b3dd commit 4c9b487

File tree

4 files changed

+55
-47
lines changed

4 files changed

+55
-47
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ include("wrappers.jl")
2626
export DefaultLinSolve
2727
export LUFactorization, SVDFactorization, QRFactorization
2828
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB,
29+
KrylovJL_MINRES,
2930
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
30-
IterativeSolversJL_BICGSTAB
31+
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
3132

3233
end

src/common.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
4848
)
4949
@unpack A, b, u0, p = prob
5050

51-
if u0 == nothing
52-
u0 = zero(b)
53-
end
51+
u0 = (u0 == nothing) ? zero(b) : u0
5452

5553
cacheval = init_cacheval(alg, A, b, u0)
5654
isfresh = cacheval == nothing

src/wrappers.jl

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
11

2-
"""
3-
TODO
4-
- account for standard kwargs: abstol, reltol, maxiter, restart, window
5-
- KrylovJL: memory, window <-- Solver, itmax, atol, rtol <-- Solve
6-
- IterativeSolversJL: restart
7-
"""
8-
9-
102
## Preconditioners
113

124
struct ScaleVector{T}
@@ -36,7 +28,7 @@ end
3628

3729
## Krylov.jl
3830

39-
struct KrylovJL{F,Tl,Tr,A,K, T, I} <: SciMLLinearSolveAlgorithm
31+
struct KrylovJL{F,Tl,Tr,T,I,A,K} <: SciMLLinearSolveAlgorithm
4032
KrylovAlg::F
4133
Pl::Tl
4234
Pr::Tr
@@ -59,9 +51,14 @@ function KrylovJL(args...; KrylovAlg = Krylov.gmres!, Pl=I, Pr=I,
5951
args, kwargs)
6052
end
6153

62-
KrylovJL_CG(args...;kwargs...) = KrylovJL(Krylov.cg!, args...; kwargs...)
63-
KrylovJL_GMRES(args...;kwargs...) = KrylovJL(Krylov.gmres!, args...; kwargs...)
64-
KrylovJL_BICGSTAB(args...;kwargs...) = KrylovJL(Krylov.bicgstab!, args...; kwargs...)
54+
KrylovJL_CG(args...;kwargs...) =
55+
KrylovJL(Krylov.cg!, args...; kwargs...)
56+
KrylovJL_GMRES(args...;kwargs...) =
57+
KrylovJL(Krylov.gmres!, args...; kwargs...)
58+
KrylovJL_BICGSTAB(args...;kwargs...) =
59+
KrylovJL(Krylov.bicgstab!, args...; kwargs...)
60+
KrylovJL_MINRES(args...;kwargs...) =
61+
KrylovJL(Krylov.minres!, args...; kwargs...)
6562

6663
const KrylovJL_solvers = Dict(
6764
(Krylov.lsmr! => Krylov.LsmrSolver ),
@@ -101,8 +98,8 @@ function init_cacheval(alg::KrylovJL, A, b, u)
10198

10299
KS = KrylovJL_solvers[alg.KrylovAlg]
103100

104-
solver =
105-
if (KS === Krylov.DqgmresSolver ||
101+
solver = if(
102+
KS === Krylov.DqgmresSolver ||
106103
KS === Krylov.DiomSolver ||
107104
KS === Krylov.GmresSolver ||
108105
KS === Krylov.FormSovler
@@ -135,7 +132,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
135132
maxiter = (alg.reltol == 0) ? length(cache.b) : alg.maxiter
136133

137134
Krylov.solve!(cache.cacheval, cache.A, cache.b;
138-
M=cache.Pl, N=cache.Pr,
135+
M=alg.Pl, N=alg.Pr,
139136
atol = abstol, rtol = reltol, itmax = maxiter,
140137
alg.kwargs...)
141138

@@ -144,24 +141,27 @@ end
144141

145142
## IterativeSolvers.jl
146143

147-
struct IterativeSolversJL{F,Tl,Tr,A,K} <: SciMLLinearSolveAlgorithm
144+
struct IterativeSolversJL{F,Tl,Tr,T,I,A,K} <: SciMLLinearSolveAlgorithm
148145
generate_iterator::F
149146
Pl::Tl
150147
Pr::Tr
148+
abstol::T
149+
reltol::T
150+
maxiter::I
151+
restart::I
151152
args::A
152153
kwargs::K
153-
# abstol::T
154-
# reltol::T
155-
# maxiter::I
156-
# restart::I
157154
end
158155

159156
function IterativeSolversJL(args...;
157+
generate_iterator = IterativeSolvers.gmres_iterable!,
160158
Pl=IterativeSolvers.Identity(),
161159
Pr=IterativeSolvers.Identity(),
162-
generate_iterator = IterativeSolvers.gmres_iterable!,
160+
abstol=0.0, reltol=0.0, maxiter=0, restart=0,
163161
kwargs...)
164-
return IterativeSolversJL(generate_iterator, Pl, Pr, args, kwargs)
162+
return IterativeSolversJL(generate_iterator, Pl, Pr,
163+
abstol, reltol, maxiter, restart,
164+
args, kwargs)
165165
end
166166

167167
IterativeSolversJL_CG(args...; kwargs...) =
@@ -170,43 +170,51 @@ IterativeSolversJL_GMRES(args...;kwargs...) =
170170
IterativeSolversJL(IterativeSolvers.gmres_iterable!, args...; kwargs...)
171171
IterativeSolversJL_BICGSTAB(args...;kwargs...) =
172172
IterativeSolversJL(IterativeSolvers.bicgstabl_iterator!, args...;kwargs...)
173+
IterativeSolversJL_MINRES(args...;kwargs...) =
174+
IterativeSolversJL(IterativeSolvers.minres_iterable!, args...;kwargs...)
173175

174-
function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr)
175-
Pl = (Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : Pl
176-
Pr = (Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : Pr
176+
function init_cacheval(alg::IterativeSolversJL, A, b, u)
177+
Pl = (alg.Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pl
178+
Pr = (alg.Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pr
177179

178-
# standard kwargs: abstol, reltol
179-
#
180-
# cg: kw: maxiter
181-
# gmres: kw: restart=20, maxiter=length(b)
182-
# bicgstabl: kw: max_mv_products
180+
abstol = (alg.abstol == 0) ? eps(eltype(b)) : alg.abstol
181+
reltol = (alg.reltol == 0) ? eps(eltype(b)) : alg.reltol
182+
maxiter = (alg.reltol == 0) ? length(b) : alg.maxiter
183183

184-
cacheval = if alg.generate_iterator === IterativeSolvers.cg_iterator!
184+
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
185185
Pr != IterativeSolvers.Identity() &&
186-
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
187-
alg.generate_iterator(u, A, b, Pl; alg.kwargs...)
186+
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
187+
alg.generate_iterator(u, A, b, Pl;
188+
abstol=abstol, reltol=reltol, maxiter=maxiter,
189+
alg.kwargs...)
188190
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
189-
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr, alg.kwargs...)
191+
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr,
192+
abstol=abstol, reltol=reltol, maxiter=maxiter,
193+
alg.kwargs...)
190194
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
191195
Pr != IterativeSolvers.Identity() &&
192-
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
193-
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl, alg.kwargs...)
194-
else
195-
alg.generate_iterator(u, A, b, alg.args...; alg.kwargs...)
196+
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
197+
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl,
198+
abstol=abstol, reltol=reltol,
199+
max_mv_products=maxiter*2,
200+
alg.kwargs...)
201+
else # minres, qmr
202+
alg.generate_iterator(u, A, b, alg.args...;
203+
abstol=abstol, reltol=reltol, maxiter=maxiter,
204+
alg.kwargs...)
196205
end
197-
return cacheval
206+
return iterable
198207
end
199208

200209
function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
201210
if cache.isfresh
202-
solver = init_cacheval(alg, cache.A, cache.b, cache.u,
203-
cache.Pl, cache.Pr)
211+
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
204212
cache = set_cacheval(cache, solver)
205213
end
206214

207215
for resi in cache.cacheval
208216
# allow for verbose, log
209-
# inject specific code into KSP solve
217+
# inject specific code into KSP solve func!(cache.cacheval)
210218
end
211219

212220
return cache.u

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ using Test
2323
# :DefaultLinSolve,
2424

2525
:KrylovJL, :KrylovJL_CG, :KrylovJL_GMRES, :KrylovJL_BICGSTAB,
26+
:KrylovJL_MINRES,
2627
:IterativeSolversJL, :IterativeSolversJL_GMRES,
27-
:IterativeSolversJL_BICGSTAB,
28+
:IterativeSolversJL_BICGSTAB, :IterativeSolversJL_MINRES
2829
)
2930
@eval begin
3031
y = solve($prob1, $alg())

0 commit comments

Comments
 (0)