Skip to content

Commit 308b3dd

Browse files
committed
KrylovJL wrapper almost done
1 parent 580897b commit 308b3dd

File tree

5 files changed

+173
-35
lines changed

5 files changed

+173
-35
lines changed

src/LinearSolve.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ include("wrappers.jl")
2525

2626
export DefaultLinSolve
2727
export LUFactorization, SVDFactorization, QRFactorization
28-
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB
29-
export IterativeSolversJL #, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
30-
#IterativeSolversJL_BICGSTAB
31-
export KrylovKitJL #, KrylovKitJL_CG, KrylovKitJL_GMRES, KrylovKitJL_BICGSTAB
28+
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB,
29+
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
30+
IterativeSolversJL_BICGSTAB
3231

3332
end

src/common.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ function (cache::LinearCache)(prob::LinearProblem, args...; kwargs...)
121121
end
122122

123123
function (cache::LinearCache)(x, A, b, args...; kwargs...)
124-
125124
prob = LinearProblem(A, b; u0=x)
126125
x = cache(prob, args...; kwargs...)
127126
return x

src/factorization.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,15 @@ function init_cacheval(alg::SVDFactorization, A, b, u)
7272
return fact
7373
end
7474

75+
## DefaultFactorization
76+
77+
struct DefaultFactorization <: AbstractFactorization
78+
end
79+
80+
function init_cacheval(alg::DefaultFactorization, A, b, u)
81+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
82+
error("DefaultFactorization is not defined for $(typeof(A))")
83+
84+
fact = factorize(A)
85+
return fact
86+
end

src/wrappers.jl

Lines changed: 156 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,127 @@
1+
2+
"""
3+
TODO
4+
- account for standard kwargs: abstol, reltol, maxiter, restart, window
5+
- KrylovJL: memory, window <-- Solver, itmax, atol, rtol <-- Solve
6+
- IterativeSolversJL: restart
7+
"""
8+
9+
10+
## Preconditioners
11+
12+
struct ScaleVector{T}
13+
s::T
14+
isleft::Bool
15+
end
16+
17+
function LinearAlgebra.ldiv!(v::ScaleVector, x)
18+
end
19+
20+
function LinearAlgebra.ldiv!(y, v::ScaleVector, x)
21+
end
22+
23+
struct ComposePreconditioner{Ti,To}
24+
inner::Ti
25+
outer::To
26+
isleft::Bool
27+
end
28+
29+
function LinearAlgebra.ldiv!(v::ComposePreconditioner, x)
30+
@unpack inner, outer, isleft = v
31+
end
32+
33+
function LinearAlgebra.ldiv!(y, v::ComposePreconditioner, x)
34+
@unpack inner, outer, isleft = v
35+
end
36+
137
## Krylov.jl
238

3-
struct KrylovJL{F,A,K} <: SciMLLinearSolveAlgorithm
39+
struct KrylovJL{F,Tl,Tr,A,K, T, I} <: SciMLLinearSolveAlgorithm
440
KrylovAlg::F
41+
Pl::Tl
42+
Pr::Tr
43+
abstol::T
44+
reltol::T
45+
maxiter::I
46+
restart::I
47+
window::I
548
args::A
649
kwargs::K
750
end
851

9-
function KrylovJL(args...; KrylovAlg = Krylov.gmres!, kwargs...)
10-
return KrylovJL(KrylovAlg, args, kwargs)
52+
function KrylovJL(args...; KrylovAlg = Krylov.gmres!, Pl=I, Pr=I,
53+
abstol=0.0, reltol=0.0, maxiter=0, # for solver call
54+
restart=20, window=0, # for building solver
55+
kwargs...)
56+
57+
return KrylovJL(KrylovAlg, Pl, Pr, abstol, reltol, maxiter,
58+
restart, window,
59+
args, kwargs)
1160
end
1261

1362
KrylovJL_CG(args...;kwargs...) = KrylovJL(Krylov.cg!, args...; kwargs...)
1463
KrylovJL_GMRES(args...;kwargs...) = KrylovJL(Krylov.gmres!, args...; kwargs...)
1564
KrylovJL_BICGSTAB(args...;kwargs...) = KrylovJL(Krylov.bicgstab!, args...; kwargs...)
1665

66+
const KrylovJL_solvers = Dict(
67+
(Krylov.lsmr! => Krylov.LsmrSolver ),
68+
(Krylov.cgs! => Krylov.CgsSolver ),
69+
(Krylov.usymlq! => Krylov.UsymlqSolver ),
70+
(Krylov.lnlq! => Krylov.LnlqSolver ),
71+
(Krylov.bicgstab! => Krylov.BicgstabSolver ),
72+
(Krylov.crls! => Krylov.CrlsSolver ),
73+
(Krylov.lsqr! => Krylov.LsqrSolver ),
74+
(Krylov.minres! => Krylov.MinresSolver ),
75+
(Krylov.cgne! => Krylov.CgneSolver ),
76+
(Krylov.dqgmres! => Krylov.DqgmresSolver ),
77+
(Krylov.symmlq! => Krylov.SymmlqSolver ),
78+
(Krylov.trimr! => Krylov.TrimrSolver ),
79+
(Krylov.usymqr! => Krylov.UsymqrSolver ),
80+
(Krylov.bilqr! => Krylov.BilqrSolver ),
81+
(Krylov.cr! => Krylov.CrSolver ),
82+
(Krylov.craigmr! => Krylov.CraigmrSolver ),
83+
(Krylov.tricg! => Krylov.TricgSolver ),
84+
(Krylov.craig! => Krylov.CraigSolver ),
85+
(Krylov.diom! => Krylov.DiomSolver ),
86+
(Krylov.lslq! => Krylov.LslqSolver ),
87+
(Krylov.trilqr! => Krylov.TrilqrSolver ),
88+
(Krylov.crmr! => Krylov.CrmrSolver ),
89+
(Krylov.cg! => Krylov.CgSolver ),
90+
(Krylov.cg_lanczos! => Krylov.CgLanczosShiftSolver),
91+
(Krylov.cgls! => Krylov.CglsSolver ),
92+
(Krylov.cg_lanczos! => Krylov.CgLanczosSolver ),
93+
(Krylov.bilq! => Krylov.BilqSolver ),
94+
(Krylov.minres_qlp! => Krylov.MinresQlpSolver ),
95+
(Krylov.qmr! => Krylov.QmrSolver ),
96+
(Krylov.gmres! => Krylov.GmresSolver ),
97+
(Krylov.fom! => Krylov.FomSolver ),
98+
)
99+
17100
function init_cacheval(alg::KrylovJL, A, b, u)
18-
cacheval = if alg.KrylovAlg === Krylov.cg!
19-
Krylov.CgSolver(A,b)
20-
elseif alg.KrylovAlg === Krylov.gmres!
21-
Krylov.GmresSolver(A,b,20)
22-
elseif alg.KrylovAlg === Krylov.bicgstab!
23-
Krylov.BicgstabSolver(A,b)
101+
102+
KS = KrylovJL_solvers[alg.KrylovAlg]
103+
104+
solver =
105+
if (KS === Krylov.DqgmresSolver ||
106+
KS === Krylov.DiomSolver ||
107+
KS === Krylov.GmresSolver ||
108+
KS === Krylov.FormSovler
109+
)
110+
KS(A, b, alg.restart)
111+
elseif(KS === Krylov.MinresSolver ||
112+
KS === Krylov.SymmlqSolver ||
113+
KS === Krylov.LslqSolver ||
114+
KS === Krylov.LsqrSolver ||
115+
KS === Krylov.LsmrSolver
116+
)
117+
(alg.window != 0) ? KS(A,b; window=alg.window) : KS(A, b)
24118
else
25-
nothing
119+
KS(A, b)
26120
end
27-
return cacheval
121+
122+
solver.x = u
123+
124+
return solver
28125
end
29126

30127
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
@@ -33,54 +130,84 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
33130
cache = set_cacheval(cache, solver)
34131
end
35132

36-
cache.cacheval.x = cache.u
37-
alg.KrylovAlg(cache.cacheval, cache.A, cache.b;
38-
M=cache.Pl, N=cache.Pr, alg.kwargs...)
133+
abstol = (alg.abstol == 0) ? eps(eltype(cache.b)) : alg.abstol
134+
reltol = (alg.reltol == 0) ? eps(eltype(cache.b)) : alg.reltol
135+
maxiter = (alg.reltol == 0) ? length(cache.b) : alg.maxiter
136+
137+
Krylov.solve!(cache.cacheval, cache.A, cache.b;
138+
M=cache.Pl, N=cache.Pr,
139+
atol = abstol, rtol = reltol, itmax = maxiter,
140+
alg.kwargs...)
39141

40142
return cache.u
41143
end
42144

43145
## IterativeSolvers.jl
44146

45-
struct IterativeSolversJL{F,A,K} <: SciMLLinearSolveAlgorithm
147+
struct IterativeSolversJL{F,Tl,Tr,A,K} <: SciMLLinearSolveAlgorithm
46148
generate_iterator::F
149+
Pl::Tl
150+
Pr::Tr
47151
args::A
48152
kwargs::K
153+
# abstol::T
154+
# reltol::T
155+
# maxiter::I
156+
# restart::I
49157
end
50158

51159
function IterativeSolversJL(args...;
160+
Pl=IterativeSolvers.Identity(),
161+
Pr=IterativeSolvers.Identity(),
52162
generate_iterator = IterativeSolvers.gmres_iterable!,
53163
kwargs...)
54-
return IterativeSolversJL(generate_iterator, args, kwargs)
164+
return IterativeSolversJL(generate_iterator, Pl, Pr, args, kwargs)
55165
end
56166

57-
#IterativeSolversJL_CG(args...; kwargs...)
58-
# = IterativeSolversJL(IterativeSolvers.cg_iterator!, args...; kwargs...)
59-
#IterativeSolversJL_GMRES(args...;kwargs...)
60-
# = IterativeSolversJL(IterativeSolvers.gmres_iterable!, args...; kwargs...)
61-
#IterativeSolversJL_BICGSTAB(args...;kwargs...)
62-
# = IterativeSolversJL(IterativeSolvers.bicgstabl_iterator!, args...;kwargs...)
167+
IterativeSolversJL_CG(args...; kwargs...) =
168+
IterativeSolversJL(IterativeSolvers.cg_iterator!, args...; kwargs...)
169+
IterativeSolversJL_GMRES(args...;kwargs...) =
170+
IterativeSolversJL(IterativeSolvers.gmres_iterable!, args...; kwargs...)
171+
IterativeSolversJL_BICGSTAB(args...;kwargs...) =
172+
IterativeSolversJL(IterativeSolvers.bicgstabl_iterator!, args...;kwargs...)
173+
174+
function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr)
175+
Pl = (Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : Pl
176+
Pr = (Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : Pr
177+
178+
# standard kwargs: abstol, reltol
179+
#
180+
# cg: kw: maxiter
181+
# gmres: kw: restart=20, maxiter=length(b)
182+
# bicgstabl: kw: max_mv_products
63183

64-
function init_cacheval(alg::IterativeSolversJL, A, b, u)
65184
cacheval = if alg.generate_iterator === IterativeSolvers.cg_iterator!
66-
alg.generate_iterator(u, A, b)
185+
Pr != IterativeSolvers.Identity() &&
186+
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
187+
alg.generate_iterator(u, A, b, Pl; alg.kwargs...)
67188
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
68-
alg.generate_iterator(u, A, b)
189+
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr, alg.kwargs...)
69190
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
70-
alg.generate_iterator(u, A, b)
191+
Pr != IterativeSolvers.Identity() &&
192+
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
193+
alg.generate_iterator(u, A, b, alg.args...; Pl=Pl, alg.kwargs...)
71194
else
72-
alg.generate_iterator(u, A, b)
195+
alg.generate_iterator(u, A, b, alg.args...; alg.kwargs...)
73196
end
74197
return cacheval
75198
end
76199

77200
function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
78201
if cache.isfresh
79-
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
202+
solver = init_cacheval(alg, cache.A, cache.b, cache.u,
203+
cache.Pl, cache.Pr)
80204
cache = set_cacheval(cache, solver)
81205
end
82206

83-
for resi in cache.cacheval end
207+
for resi in cache.cacheval
208+
# allow for verbose, log
209+
# inject specific code into KSP solve
210+
end
84211

85212
return cache.u
86213
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ using Test
2323
# :DefaultLinSolve,
2424

2525
:KrylovJL, :KrylovJL_CG, :KrylovJL_GMRES, :KrylovJL_BICGSTAB,
26-
:IterativeSolversJL,#:IterativeSolversJL_GMRES, :IterativeSolversJL_BICGSTAB,
26+
:IterativeSolversJL, :IterativeSolversJL_GMRES,
27+
:IterativeSolversJL_BICGSTAB,
2728
)
2829
@eval begin
2930
y = solve($prob1, $alg())

0 commit comments

Comments
 (0)