Skip to content

Commit 3b1ef17

Browse files
committed
init_cacheval
1 parent 9d1def4 commit 3b1ef17

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ function set_cacheval(cache, alg_cache)
4040
return cache
4141
end
4242

43-
init_cacheval(A, alg::SciMLLinearSolveAlgorithm) = nothing
43+
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
4444

4545
function SciMLBase.init(prob::LinearProblem, alg, args...;
4646
alias_A = false, alias_b = false,
@@ -52,7 +52,7 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
5252
u0 = zero(b)
5353
end
5454

55-
cacheval = init_cacheval(prob.A, alg)
55+
cacheval = init_cacheval(alg, A, b, u0)
5656
Tc = cacheval == nothing ? Any : typeof(cacheval)
5757
isfresh = cacheval == nothing
5858

src/wrappers.jl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,40 @@
11
## Krylov.jl
22

33
struct KrylovJL{F,A,K} <: SciMLLinearSolveAlgorithm
4-
solver::F
4+
KrylovAlg::F
55
args::A
66
kwargs::K
77
end
88

9-
function KrylovJL(args...; solver = Krylov.gmres, kwargs...)
10-
return KrylovJL(solver, args, kwargs)
9+
function KrylovJL(args...; KrylovAlg = Krylov.gmres!, kwargs...)
10+
return KrylovJL(KrylovAlg, args, kwargs)
1111
end
1212

13-
# place Krylov.CGsolver in LinearCache.cacheval for reuse
14-
function init_cacheval(prob::LinearProblem, alg::KrylovJL)
15-
if alg.solver === Krylov.cg!
16-
elseif alg.solver === Krylov.gmres!
17-
elseif alg.solver === Krylov.bicgstab!
13+
function init_cacheval(alg::KrylovJL, A, b, u)
14+
cacheval = if alg.KrylovAlg === Krylov.cg!
15+
CgSolver(A,b)
16+
elseif alg.KrylovAlg === Krylov.gmres!
17+
GmresSolver(A,b,20)
18+
elseif alg.KrylovAlg === Krylov.bicgstab!
19+
BicgstabSolver(A,b)
20+
else
21+
nothing
1822
end
19-
return
23+
return cacheval
2024
end
2125

22-
# KrylovJL failing in-place
23-
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL,args...;kwargs...)
24-
@unpack A, b, u, Pr, Pl = cache
25-
u, stats = alg.solver(A, b, args...; M=Pl, N=Pr, kwargs...)
26-
resid = A * u - b
27-
retcode = stats.solved ? :Success : :Failure
28-
return u
26+
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
27+
@unpack A, b, u, Pr, Pl, cacheval = cache
28+
29+
if cache.isfresh
30+
solver = init_cacheval(alg.KrylovAlg, A, b, u)
31+
solver.x = u
32+
cache = set_cacheval(cache, solver)
33+
end
34+
35+
alg.solver(cacheval, A, b; M=Pl, N=Pr, alg.kwargs...)
36+
37+
return cache.u
2938
end
3039

3140
KrylovJL_CG(args...;kwargs...) = KrylovJL(Krylov.cg!, args...; kwargs...)

0 commit comments

Comments
 (0)