Skip to content

Commit 623e45c

Browse files
committed
reuse cache for multiple solves
1 parent d008964 commit 623e45c

File tree

2 files changed

+42
-15
lines changed

2 files changed

+42
-15
lines changed

src/common.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tr,Tl}
1111
# k::Tk # iteration count
1212
end
1313

14-
function set_A(cache, A)
14+
function set_A(cache, A) # and ! to function name
1515
@set! cache.A = A
1616
@set! cache.isfresh = true
1717
end
@@ -20,6 +20,10 @@ function set_b(cache, b)
2020
@set! cache.b = b
2121
end
2222

23+
function set_u(cache, u)
24+
@set! cache.u = u
25+
end
26+
2327
function set_p(cache, p)
2428
@set! cache.p = p
2529
# @set! cache.isfresh = true
@@ -57,10 +61,8 @@ function SciMLBase.init(
5761
Pr = LinearAlgebra.I
5862
Pl = LinearAlgebra.I
5963

60-
# @show (A, b, u0, p) |> typeof
61-
62-
A = alias_A ? A : copy(A)
63-
b = alias_b ? b : copy(b)
64+
A = alias_A ? A : deepcopy(A)
65+
b = alias_b ? b : deepcopy(b)
6466

6567
cache = LinearCache{
6668
typeof(A),
@@ -95,3 +97,21 @@ function (alg::SciMLLinearSolveAlgorithm)(x,A,b,args...;u0=nothing,kwargs...)
9597
x = solve(prob,alg,args...;kwargs...)
9698
return x
9799
end
100+
101+
# how to initialize cahce?
102+
103+
# use the same cache to solve multiple linear problems
104+
function (cache::LinearCache)(x,A,b,args...;u0=nothing,kwargs...)
105+
set_A(cache, A)
106+
set_b(cache, b)
107+
108+
if u0 == nothing
109+
x = zero(x)
110+
else
111+
x = u0
112+
end
113+
set_u(cache, x)
114+
115+
x = solve(cache)
116+
return x
117+
end

test/runtests.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,23 @@ using Test
2424

2525
# Factorization
2626
for alg in (:LUFactorization, :QRFactorization, :SVDFactorization,
27-
:KrylovJL)
27+
:KrylovJL,
28+
# :KrylovKitJL,
29+
)
2830
@eval begin
2931
@test $A * solve($prob, $alg();) $b
30-
$alg()($x,$A,$b)
32+
$alg()($x, $A, $b)
33+
@test $A * $x $b
34+
35+
cache = SciMLBase.init($prob, $alg())
36+
cache($x, $A, $b)
3137
@test $A * $x $b
3238
end
3339
end
3440

3541
# test on some ODEProblem
36-
using OrdinaryDiffEq
37-
using DiffEqProblemLibrary.ODEProblemLibrary
38-
ODEProblemLibrary.importodeproblems()
39-
40-
# add this problem to DiffEqProblemLibrary
42+
# using OrdinaryDiffEq
43+
# # add this problem to DiffEqProblemLibrary
4144
# kx = 1
4245
# kt = 1
4346
# ut(x,t) = sin(kx*pi*x)*cos(kt*pi*t)
@@ -50,8 +53,12 @@ using Test
5053
# func = ODEFunction(dudt!)
5154
# prob = ODEProblem(func,u0,tspn)
5255

53-
prob = ODEProblemLibrary.prob_ode_linear
54-
sol = solve(prob, Rodas5(linsolve=SVDFactorization()); saveat=0.1)
55-
@show sol.retcode
56+
# using OrdinaryDiffEq
57+
# using DiffEqProblemLibrary.ODEProblemLibrary
58+
# ODEProblemLibrary.importodeproblems()
59+
# prob = ODEProblemLibrary.prob_ode_linear
60+
# @show prob
61+
# sol = solve(prob, Rodas5(linsolve=KrylovJL()); saveat=0.1)
62+
# @show sol.retcode
5663

5764
end

0 commit comments

Comments
 (0)