Skip to content

Commit 114e827

Browse files
Merge pull request #30 from vpuri3/vp/common-args
common arg handling
2 parents 3b9379b + 641d11e commit 114e827

File tree

3 files changed

+51
-43
lines changed

3 files changed

+51
-43
lines changed

src/common.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
1+
struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr,Ttol}
22
A::TA
33
b::Tb
44
u::Tu
@@ -8,6 +8,10 @@ struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
88
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
99
Pl::Tl # store final preconditioner here. not being used rn
1010
Pr::Tr # wrappers are using preconditioner in cache.alg for now
11+
abstol::Ttol
12+
reltol::Ttol
13+
maxiters::Int
14+
verbose::Bool
1115
end
1216

1317
function set_A(cache::LinearCache, A)
@@ -46,6 +50,10 @@ SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,no
4650

4751
function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing}, args...;
4852
alias_A = false, alias_b = false,
53+
abstol=eps(eltype(prob.A)),
54+
reltol=eps(eltype(prob.A)),
55+
maxiters=length(prob.b),
56+
verbose=false,
4957
kwargs...,
5058
)
5159
@unpack A, b, u0, p = prob
@@ -71,6 +79,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
7179
Tc,
7280
typeof(Pl),
7381
typeof(Pr),
82+
typeof(reltol),
7483
}(
7584
A,
7685
b,
@@ -81,6 +90,10 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
8190
isfresh,
8291
Pl,
8392
Pr,
93+
abstol,
94+
reltol,
95+
maxiters,
96+
verbose,
8497
)
8598
return cache
8699
end

src/wrappers.jl

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,21 @@ end
3131

3232
## Krylov.jl
3333

34-
struct KrylovJL{F,Tl,Tr,T,I,A,K} <: AbstractKrylovSubspaceMethod
34+
struct KrylovJL{F,Tl,Tr,I,A,K} <: AbstractKrylovSubspaceMethod
3535
KrylovAlg::F
3636
Pl::Tl
3737
Pr::Tr
38-
abstol::T
39-
reltol::T
40-
maxiter::I
41-
ifverbose::Bool
4238
gmres_restart::I
4339
window::I
4440
args::A
4541
kwargs::K
4642
end
4743

4844
function KrylovJL(args...; KrylovAlg = Krylov.gmres!, Pl=I, Pr=I,
49-
abstol=0.0, reltol=0.0, maxiter=0, ifverbose=false,
50-
gmres_restart=20, window=0, # for building solver
45+
gmres_restart=0, window=0,
5146
kwargs...)
5247

53-
return KrylovJL(KrylovAlg, Pl, Pr, abstol, reltol, maxiter, ifverbose,
54-
gmres_restart, window,
48+
return KrylovJL(KrylovAlg, Pl, Pr, gmres_restart, window,
5549
args, kwargs)
5650
end
5751

@@ -106,13 +100,15 @@ function init_cacheval(alg::KrylovJL, A, b, u)
106100

107101
KS = get_KrylovJL_solver(alg.KrylovAlg)
108102

103+
memory = (alg.gmres_restart == 0) ? min(20, size(A,1)) : alg.gmres_restart
104+
109105
solver = if(
110106
alg.KrylovAlg === Krylov.dqgmres! ||
111107
alg.KrylovAlg === Krylov.diom! ||
112108
alg.KrylovAlg === Krylov.gmres! ||
113109
alg.KrylovAlg === Krylov.fom!
114110
)
115-
KS(A, b, alg.gmres_restart)
111+
KS(A, b, memory)
116112
elseif(
117113
alg.KrylovAlg === Krylov.minres! ||
118114
alg.KrylovAlg === Krylov.symmlq! ||
@@ -136,13 +132,13 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
136132
cache = set_cacheval(cache, solver)
137133
end
138134

139-
abstol = (alg.abstol == 0) ? eps(eltype(cache.b)) : alg.abstol
140-
reltol = (alg.reltol == 0) ? eps(eltype(cache.b)) : alg.reltol
141-
maxiter = (alg.maxiter == 0) ? length(cache.b) : alg.maxiter
142-
verbose = alg.ifverbose ? 1 : 0
135+
atol = cache.abstol
136+
rtol = cache.reltol
137+
itmax = cache.maxiters
138+
verbose = cache.verbose ? 1 : 0
143139

144140
args = (cache.cacheval, cache.A, cache.b)
145-
kwargs = (atol=abstol, rtol=reltol, itmax=maxiter, verbose=verbose,
141+
kwargs = (atol=atol, rtol=rtol, itmax=itmax, verbose=verbose,
146142
alg.kwargs...)
147143

148144
if cache.cacheval isa Krylov.CgSolver
@@ -170,14 +166,10 @@ end
170166

171167
## IterativeSolvers.jl
172168

173-
struct IterativeSolversJL{F,Tl,Tr,T,I,A,K} <: AbstractKrylovSubspaceMethod
169+
struct IterativeSolversJL{F,Tl,Tr,I,A,K} <: AbstractKrylovSubspaceMethod
174170
generate_iterator::F
175171
Pl::Tl
176172
Pr::Tr
177-
abstol::T
178-
reltol::T
179-
maxiter::I
180-
ifverbose::Bool
181173
gmres_restart::I
182174
args::A
183175
kwargs::K
@@ -187,11 +179,9 @@ function IterativeSolversJL(args...;
187179
generate_iterator = IterativeSolvers.gmres_iterable!,
188180
Pl=IterativeSolvers.Identity(),
189181
Pr=IterativeSolvers.Identity(),
190-
abstol=0.0, reltol=0.0, maxiter=0, ifverbose=true,
191182
gmres_restart=0, kwargs...)
192-
return IterativeSolversJL(generate_iterator, Pl, Pr,
193-
abstol, reltol, maxiter, ifverbose,
194-
gmres_restart, args, kwargs)
183+
return IterativeSolversJL(generate_iterator, Pl, Pr, gmres_restart,
184+
args, kwargs)
195185
end
196186

197187
IterativeSolversJL_CG(args...; kwargs...) =
@@ -211,24 +201,29 @@ IterativeSolversJL_MINRES(args...;kwargs...) =
211201
generate_iterator=IterativeSolvers.minres_iterable!,
212202
kwargs...)
213203

214-
function init_cacheval(alg::IterativeSolversJL, A, b, u)
204+
function init_cacheval(alg::IterativeSolversJL, cache::LinearCache)
205+
@unpack A, b, u = cache
206+
215207
Pl = (alg.Pl == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pl
216208
Pr = (alg.Pr == LinearAlgebra.I) ? IterativeSolvers.Identity() : alg.Pr
217209

218-
abstol = (alg.abstol == 0) ? eps(eltype(b)) : alg.abstol
219-
reltol = (alg.reltol == 0) ? eps(eltype(b)) : alg.reltol
220-
maxiter = (alg.maxiter == 0) ? length(b) : alg.maxiter
210+
abstol = cache.abstol
211+
reltol = cache.reltol
212+
maxiter = cache.maxiters
213+
verbose = cache.verbose
221214

222-
# args = (u, A, b)
223-
kwargs = (abstol=abstol, reltol=reltol, maxiter=maxiter, alg.kwargs...)
215+
restart = (alg.gmres_restart == 0) ? min(20, size(A,1)) : alg.gmres_restart
216+
217+
kwargs = (abstol=abstol, reltol=reltol, maxiter=maxiter,
218+
alg.kwargs...)
224219

225220
iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
226221
Pr != IterativeSolvers.Identity() &&
227222
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
228223
alg.generate_iterator(u, A, b, Pl;
229224
kwargs...)
230225
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
231-
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr,
226+
alg.generate_iterator(u, A, b; Pl=Pl, Pr=Pr, restart=restart,
232227
kwargs...)
233228
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
234229
Pr != IterativeSolvers.Identity() &&
@@ -247,16 +242,16 @@ end
247242

248243
function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
249244
if cache.isfresh
250-
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
245+
solver = init_cacheval(alg, cache)
251246
cache = set_cacheval(cache, solver)
252247
end
253248

254-
alg.ifverbose && println("Using IterativeSolvers.$(alg.generate_iterator)")
249+
cache.verbose && println("Using IterativeSolvers.$(alg.generate_iterator)")
255250
for iter in enumerate(cache.cacheval)
256-
alg.ifverbose && println("Iter: $(iter[1]), residual: $(iter[2])")
257-
# inject callbacks KSP into solve cb!(cache.cacheval)
251+
cache.verbose && println("Iter: $(iter[1]), residual: $(iter[2])")
252+
# TODO inject callbacks KSP into solve cb!(cache.cacheval)
258253
end
259-
alg.ifverbose && println()
254+
cache.verbose && println()
260255

261256
return cache.u
262257
end

test/runtests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@ A2 = A/2; b2 = rand(n); x2 = zero(b)
1010
prob1 = LinearProblem(A1, b1; u0=x1)
1111
prob2 = LinearProblem(A2, b2; u0=x2)
1212

13+
cache_kwargs = (;verbose=true, abstol=1e-8, reltol=1e-8, maxiter=30,)
14+
1315
function test_interface(alg, prob1, prob2)
1416
A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
1517
A2 = prob2.A; b2 = prob2.b; x2 = prob2.u0
1618

17-
y = solve(prob1, alg)
19+
y = solve(prob1, alg; cache_kwargs...)
1820
@test A1 * y b1
1921

20-
cache = SciMLBase.init(prob1,alg) # initialize cache
22+
cache = SciMLBase.init(prob1,alg; cache_kwargs...) # initialize cache
2123
y = solve(cache)
2224
@test A1 * y b1
2325

@@ -92,8 +94,7 @@ end
9294
end
9395

9496
@testset "KrylovJL" begin
95-
kwargs = (;ifverbose=false, abstol=1e-8, reltol=1e-8, maxiter=30,
96-
gmres_restart=5)
97+
kwargs = (;gmres_restart=5,)
9798
for alg in (
9899
("Default",KrylovJL(kwargs...)),
99100
("CG",KrylovJL_CG(kwargs...)),
@@ -108,8 +109,7 @@ end
108109
end
109110

110111
@testset "IterativeSolversJL" begin
111-
kwargs = (;ifverbose=false, abstol=1e-8, reltol=1e-8, maxiter=30,
112-
gmres_restart=5)
112+
kwargs = (;gmres_restart=5,)
113113
for alg in (
114114
("Default", IterativeSolversJL(kwargs...)),
115115
("CG", IterativeSolversJL_CG(kwargs...)),

0 commit comments

Comments
 (0)