Skip to content

Commit 580897b

Browse files
committed
iterativesolvers.jl wrapper
1 parent 0ac1a80 commit 580897b

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

src/wrappers.jl

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ function KrylovJL(args...; KrylovAlg = Krylov.gmres!, kwargs...)
1010
return KrylovJL(KrylovAlg, args, kwargs)
1111
end
1212

13+
KrylovJL_CG(args...;kwargs...) = KrylovJL(Krylov.cg!, args...; kwargs...)
14+
KrylovJL_GMRES(args...;kwargs...) = KrylovJL(Krylov.gmres!, args...; kwargs...)
15+
KrylovJL_BICGSTAB(args...;kwargs...) = KrylovJL(Krylov.bicgstab!, args...; kwargs...)
16+
1317
function init_cacheval(alg::KrylovJL, A, b, u)
1418
cacheval = if alg.KrylovAlg === Krylov.cg!
1519
Krylov.CgSolver(A,b)
@@ -36,34 +40,48 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
3640
return cache.u
3741
end
3842

39-
KrylovJL_CG(args...;kwargs...) = KrylovJL(Krylov.cg!, args...; kwargs...)
40-
KrylovJL_GMRES(args...;kwargs...) = KrylovJL(Krylov.gmres!, args...; kwargs...)
41-
KrylovJL_BICGSTAB(args...;kwargs...) = KrylovJL(Krylov.bicgstab!, args...; kwargs...)
42-
4343
## IterativeSolvers.jl
4444

4545
struct IterativeSolversJL{F,A,K} <: SciMLLinearSolveAlgorithm
46-
solver::F
46+
generate_iterator::F
4747
args::A
4848
kwargs::K
4949
end
5050

51-
## KrylovKit.jl
52-
53-
struct KrylovKitJL{F,A,K} <: SciMLLinearSolveAlgorithm
54-
solver::F
55-
args::A
56-
kwargs::K
51+
function IterativeSolversJL(args...;
52+
generate_iterator = IterativeSolvers.gmres_iterable!,
53+
kwargs...)
54+
return IterativeSolversJL(generate_iterator, args, kwargs)
5755
end
5856

59-
function KrylovKitJL(args...; solver = KrylovKit.CG(), kwargs...)
60-
return KrylovKitJL(solver, args, kwargs)
57+
#IterativeSolversJL_CG(args...; kwargs...)
58+
# = IterativeSolversJL(IterativeSolvers.cg_iterator!, args...; kwargs...)
59+
#IterativeSolversJL_GMRES(args...;kwargs...)
60+
# = IterativeSolversJL(IterativeSolvers.gmres_iterable!, args...; kwargs...)
61+
#IterativeSolversJL_BICGSTAB(args...;kwargs...)
62+
# = IterativeSolversJL(IterativeSolvers.bicgstabl_iterator!, args...;kwargs...)
63+
64+
function init_cacheval(alg::IterativeSolversJL, A, b, u)
65+
cacheval = if alg.generate_iterator === IterativeSolvers.cg_iterator!
66+
alg.generate_iterator(u, A, b)
67+
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
68+
alg.generate_iterator(u, A, b)
69+
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
70+
alg.generate_iterator(u, A, b)
71+
else
72+
alg.generate_iterator(u, A, b)
73+
end
74+
return cacheval
6175
end
6276

63-
function SciMLBase.solve(cache::LinearCache, alg::KrylovKitJL,args...;kwargs...)
64-
@unpack A, b, u = cache
65-
@unpack solver = alg
66-
u = KrylovKit.linsolve(A, b, u, solver, args...; kwargs...)[1] #no precond?!
67-
return u
77+
function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
78+
if cache.isfresh
79+
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
80+
cache = set_cacheval(cache, solver)
81+
end
82+
83+
for resi in cache.cacheval end
84+
85+
return cache.u
6886
end
6987

test/runtests.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ using Test
33

44
@testset "LinearSolve.jl" begin
55
using LinearAlgebra
6-
n = 2
6+
n = 32
77

88
A = Matrix(I,n,n)
99
b = ones(n)
10-
A1 = A/1; b1 = ones(n); x1 = zero(b)
11-
A2 = A/2; b2 = ones(n); x2 = zero(b)
12-
A3 = A/3; b3 = ones(n); x3 = zero(b)
10+
A1 = A/1; b1 = rand(n); x1 = zero(b)
11+
A2 = A/2; b2 = rand(n); x2 = zero(b)
12+
A3 = A/3; b3 = rand(n); x3 = zero(b)
1313

1414
prob1 = LinearProblem(A1, b1; u0=x1)
1515
prob2 = LinearProblem(A2, b2; u0=x2)
@@ -23,9 +23,7 @@ using Test
2323
# :DefaultLinSolve,
2424

2525
:KrylovJL, :KrylovJL_CG, :KrylovJL_GMRES, :KrylovJL_BICGSTAB,
26-
# :IterativeSolversJL
27-
# :KrylovKitJL,
28-
26+
:IterativeSolversJL,#:IterativeSolversJL_GMRES, :IterativeSolversJL_BICGSTAB,
2927
)
3028
@eval begin
3129
y = solve($prob1, $alg())

0 commit comments

Comments
 (0)