Skip to content

Commit 0ac1a80

Browse files
committed
krylov.jl wrapper working
1 parent e0a2408 commit 0ac1a80

File tree

4 files changed

+21
-22
lines changed

4 files changed

+21
-22
lines changed

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ include("wrappers.jl")
2525

2626
export DefaultLinSolve
2727
export LUFactorization, SVDFactorization, QRFactorization
28-
export KrylovJL #, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB
28+
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB
2929
export IterativeSolversJL #, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
3030
#IterativeSolversJL_BICGSTAB
3131
export KrylovKitJL #, KrylovKitJL_CG, KrylovKitJL_GMRES, KrylovKitJL_BICGSTAB

src/common.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,29 @@ struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
1010
Pr::Tr
1111
end
1212

13-
function set_A(cache, A) # and ! to function name
13+
function set_A(cache::LinearCache, A) # and ! to function name
1414
@set! cache.A = A
1515
@set! cache.isfresh = true
1616
return cache
1717
end
1818

19-
function set_b(cache, b)
19+
function set_b(cache::LinearCache, b)
2020
@set! cache.b = b
2121
return cache
2222
end
2323

24-
function set_u(cache, u)
24+
function set_u(cache::LinearCache, u)
2525
@set! cache.u = u
2626
return cache
2727
end
2828

29-
function set_p(cache, p)
29+
function set_p(cache::LinearCache, p)
3030
@set! cache.p = p
3131
# @set! cache.isfresh = true
3232
return cache
3333
end
3434

35-
function set_cacheval(cache, alg_cache)
35+
function set_cacheval(cache::LinearCache, alg_cache)
3636
if cache.isfresh
3737
@set! cache.cacheval = alg_cache
3838
@set! cache.isfresh = false
@@ -53,8 +53,8 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
5353
end
5454

5555
cacheval = init_cacheval(alg, A, b, u0)
56-
Tc = cacheval == nothing ? Any : typeof(cacheval)
5756
isfresh = cacheval == nothing
57+
Tc = isfresh ? Any : typeof(cacheval)
5858

5959
Pl = LinearAlgebra.I
6060
Pr = LinearAlgebra.I
@@ -108,8 +108,8 @@ end
108108

109109
function (cache::LinearCache)(prob::LinearProblem, args...; kwargs...)
110110

111-
if(prob.A != cache.A) cache = set_A(cache, prob.A) end
112-
if(prob.b != cache.b) cache = set_b(cache, prob.b) end
111+
if(prob.A != cache.A) cache = set_A(cache, prob.A) end
112+
if(prob.b != cache.b) cache = set_b(cache, prob.b) end
113113

114114
if(prob.u0 == nothing)
115115
prob.u0 = zero(x)

src/wrappers.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,26 @@ end
1212

1313
function init_cacheval(alg::KrylovJL, A, b, u)
1414
cacheval = if alg.KrylovAlg === Krylov.cg!
15-
CgSolver(A,b)
15+
Krylov.CgSolver(A,b)
1616
elseif alg.KrylovAlg === Krylov.gmres!
17-
GmresSolver(A,b,20)
17+
Krylov.GmresSolver(A,b,20)
1818
elseif alg.KrylovAlg === Krylov.bicgstab!
19-
BicgstabSolver(A,b)
19+
Krylov.BicgstabSolver(A,b)
2020
else
2121
nothing
2222
end
2323
return cacheval
2424
end
2525

2626
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
27-
@unpack A, b, u, Pr, Pl, cacheval = cache
28-
2927
if cache.isfresh
30-
solver = init_cacheval(alg.KrylovAlg, A, b, u)
31-
solver.x = u
28+
solver = init_cacheval(alg, cache.A, cache.b, cache.u)
3229
cache = set_cacheval(cache, solver)
3330
end
3431

35-
alg.solver(cacheval, A, b; M=Pl, N=Pr, alg.kwargs...)
32+
cache.cacheval.x = cache.u
33+
alg.KrylovAlg(cache.cacheval, cache.A, cache.b;
34+
M=cache.Pl, N=cache.Pr, alg.kwargs...)
3635

3736
return cache.u
3837
end

test/runtests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ using Test
33

44
@testset "LinearSolve.jl" begin
55
using LinearAlgebra
6-
n = 8
6+
n = 2
77

88
A = Matrix(I,n,n)
99
b = ones(n)
10-
A1 = A/1; b1 = rand(n); x1 = zero(b)
11-
A2 = A/2; b2 = rand(n); x2 = zero(b)
12-
A3 = A/3; b3 = rand(n); x3 = zero(b)
10+
A1 = A/1; b1 = ones(n); x1 = zero(b)
11+
A2 = A/2; b2 = ones(n); x2 = zero(b)
12+
A3 = A/3; b3 = ones(n); x3 = zero(b)
1313

1414
prob1 = LinearProblem(A1, b1; u0=x1)
1515
prob2 = LinearProblem(A2, b2; u0=x2)
@@ -22,7 +22,7 @@ using Test
2222

2323
# :DefaultLinSolve,
2424

25-
# :KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB,
25+
:KrylovJL, :KrylovJL_CG, :KrylovJL_GMRES, :KrylovJL_BICGSTAB,
2626
# :IterativeSolversJL
2727
# :KrylovKitJL,
2828

0 commit comments

Comments
 (0)