Skip to content

Commit 5ac02ac

Browse files
committed
in place tests working for factorizations. onto KSP wrappers
1 parent e3f05ea commit 5ac02ac

File tree

5 files changed

+151
-101
lines changed

5 files changed

+151
-101
lines changed

src/common.jl

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
44
u::Tu
55
p::Tp
66
alg::Talg
7-
cacheval::Tc # store alg cache here
8-
isfresh::Bool
7+
cacheval::Tc # store alg cache here
8+
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
99
Pl::Tl
1010
Pr::Tr
1111
end
@@ -28,7 +28,7 @@ end
2828

2929
function set_p(cache, p)
3030
@set! cache.p = p
31-
# @set! cache.isfresh = true
31+
# @set! cache.isfresh = true
3232
return cache
3333
end
3434

@@ -40,10 +40,7 @@ function set_cacheval(cache, alg_cache)
4040
return cache
4141
end
4242

43-
#function init_cacheval(cacheval, alg::SciMLLinearSolveAlgorithm)
44-
#
45-
# return
46-
#end
43+
init_cacheval(A, alg::SciMLLinearSolveAlgorithm) = nothing
4744

4845
function SciMLBase.init(prob::LinearProblem, alg, args...;
4946
alias_A = false, alias_b = false,
@@ -55,13 +52,10 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
5552
u0 = zero(b)
5653
end
5754

58-
if alg isa LUFactorization
59-
fact = lu_instance(A)
60-
Tfact = typeof(fact)
61-
else
62-
fact = nothing
63-
Tfact = Any
64-
end
55+
cacheval = init_cacheval(prob.A, alg)
56+
Tc = cacheval == nothing ? Any : typeof(cacheval)
57+
isfresh = cacheval == nothing
58+
6559
Pl = LinearAlgebra.I
6660
Pr = LinearAlgebra.I
6761

@@ -74,7 +68,7 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
7468
typeof(u0),
7569
typeof(p),
7670
typeof(alg),
77-
Tfact,
71+
Tc,
7872
typeof(Pl),
7973
typeof(Pr),
8074
}(
@@ -83,48 +77,52 @@ function SciMLBase.init(prob::LinearProblem, alg, args...;
8377
u0,
8478
p,
8579
alg,
86-
fact,
87-
true,
80+
cacheval,
81+
isfresh,
8882
Pl,
8983
Pr,
9084
)
9185
return cache
9286
end
9387

94-
SciMLBase.solve(prob::LinearProblem, alg, args...; kwargs...) =
95-
solve(init(prob, alg, args...; kwargs...))
88+
SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
89+
args...; kwargs...) = solve(init(prob, alg, args...; kwargs...))
90+
91+
SciMLBase.solve(cache::LinearCache, args...; kwargs...) =
92+
solve(cache, cache.alg, args...; kwargs...)
9693

97-
SciMLBase.solve(cache) = solve(cache, cache.alg)
94+
## make alg callable
9895

99-
function (alg::SciMLLinearSolveAlgorithm)(prob::LinearProblem,args...;
100-
u0=nothing,kwargs...)
96+
function (alg::SciMLLinearSolveAlgorithm)(prob::LinearProblem,args...; kwargs...)
10197
x = solve(prob, alg, args...; kwargs...)
10298
return x
10399
end
104100

105101
function (alg::SciMLLinearSolveAlgorithm)(x,A,b,args...;u0=nothing,kwargs...)
106-
prob = LinearProblem(A,b;u0=x)
102+
prob = LinearProblem(A, b; u0=x)
107103
x = alg(prob, args...; kwargs...)
108104
return x
109105
end
110106

111-
function (cache::LinearCache)(prob::LinearProblem,args...;u0=nothing,kwargs...)
107+
## make cache callable - and reuse
112108

113-
if prob.u0 == nothing
109+
function (cache::LinearCache)(prob::LinearProblem, args...; kwargs...)
110+
111+
if(prob.A != cache.A) cache = set_A(cache, prob.A) end
112+
if(prob.b != cache.b) cache = set_b(cache, prob.b) end
113+
114+
if(prob.u0 == nothing)
114115
prob.u0 = zero(x)
115116
end
116117

117-
cache = set_A(cache, prob.A)
118-
cache = set_b(cache, prob.b)
119118
cache = set_u(cache, prob.u0)
120-
121-
x = solve(cache,args...;kwargs...)
119+
x = solve(cache, args...; kwargs...)
122120
return x
123121
end
124122

125-
function (cache::LinearCache)(x,A,b,args...;u0=nothing,kwargs...)
123+
function (cache::LinearCache)(x, A, b, args...; kwargs...)
126124

127-
prob = LinearProblem(A,b;u0=x)
125+
prob = LinearProblem(A, b; u0=x)
128126
x = cache(prob, args...; kwargs...)
129127
return x
130128
end

src/factorization.jl

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11

2+
## LUFactorization
3+
24
struct LUFactorization{P} <: SciMLLinearSolveAlgorithm
35
pivot::P
46
end
@@ -12,14 +14,24 @@ function LUFactorization()
1214
LUFactorization(pivot)
1315
end
1416

17+
function init_cacheval(A, alg::LUFactorization)
18+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
19+
error("LU is not defined for $(typeof(A))")
20+
fact = lu!(A, alg.pivot)
21+
return fact
22+
end
23+
1524
function SciMLBase.solve(cache::LinearCache, alg::LUFactorization)
16-
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
17-
error("LU is not defined for $(typeof(prob.A))")
18-
fact = lu!(cache.A, alg.pivot)
19-
cache = set_cacheval(cache, fact)
25+
if cache.isfresh
26+
fact = init_cacheval(cache.A, alg)
27+
cache = set_cacheval(cache, fact)
28+
end
29+
2030
ldiv!(cache.u,cache.cacheval, cache.b)
2131
end
2232

33+
## QRFactorization
34+
2335
struct QRFactorization{P} <: SciMLLinearSolveAlgorithm
2436
pivot::P
2537
blocksize::Int
@@ -34,25 +46,45 @@ function QRFactorization()
3446
QRFactorization(pivot, 16)
3547
end
3648

49+
function init_cacheval(A, alg::QRFactorization)
50+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
51+
error("QR is not defined for $(typeof(A))")
52+
53+
fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
54+
return fact
55+
end
56+
3757
function SciMLBase.solve(cache::LinearCache, alg::QRFactorization)
38-
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
39-
error("QR is not defined for $(typeof(prob.A))")
40-
fact = qr!(cache.A.A, alg.pivot; blocksize = alg.blocksize)
41-
cache = set_cacheval(cache, fact)
58+
if cache.isfresh
59+
fact = init_cacheval(cache.A, alg)
60+
cache = set_cacheval(cache, fact)
61+
end
62+
4263
ldiv!(cache.u,cache.cacheval, cache.b)
4364
end
4465

66+
## SVDFactorization
67+
4568
struct SVDFactorization{A} <: SciMLLinearSolveAlgorithm
4669
full::Bool
4770
alg::A
4871
end
4972

5073
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
5174

75+
function init_cacheval(A, alg::SVDFactorization)
76+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
77+
error("SVD is not defined for $(typeof(A))")
78+
79+
fact = svd!(A; full = alg.full, alg = alg.alg)
80+
return fact
81+
end
82+
5283
function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization)
53-
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
54-
error("SVD is not defined for $(typeof(cache.A))")
55-
fact = svd!(cache.A; full = alg.full, alg = alg.alg)
56-
cache = set_cacheval(cache, fact)
84+
if cache.isfresh
85+
fact = init_cacheval(cache.A, alg)
86+
cache = set_cacheval(cache, fact)
87+
end
88+
5789
ldiv!(cache.u,cache.cacheval, cache.b)
5890
end

src/wrappers.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
## Krylov.jl
22

3-
# place Krylov.CGsolver in LinearCache.cacheval, and resule
4-
53
struct KrylovJL{F,A,K} <: SciMLLinearSolveAlgorithm
64
solver::F
75
args::A
86
kwargs::K
97
end
108

11-
function KrylovJL(args...; solver = Krylov.bicgstab, kwargs...)
9+
function KrylovJL(args...; solver = Krylov.gmres, kwargs...)
1210
return KrylovJL(solver, args, kwargs)
1311
end
1412

13+
# place Krylov.CGsolver in LinearCache.cacheval for reuse
14+
function init_cacheval(prob::LinearProblem, alg::KrylovJL)
15+
if alg.solver === Krylov.cg!
16+
elseif alg.solver === Krylov.gmres!
17+
elseif alg.solver === Krylov.bicgstab!
18+
end
19+
return
20+
end
21+
22+
# KrylovJL failing in-place
1523
function SciMLBase.solve(cache::LinearCache, alg::KrylovJL,args...;kwargs...)
1624
@unpack A, b, u, Pr, Pl = cache
1725
u, stats = alg.solver(A, b, args...; M=Pl, N=Pr, kwargs...)

test/ode_test.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
2+
3+
n = 10
4+
dx = 2/(n-1)
5+
xx = Array(range(start=-1,stop=1,length=n))
6+
AA = Tridiagonal(ones(n-1),-2ones(n),ones(n-1))/(dx*dx) # rank-deficient sys
7+
uu = @. sin(pi*xx)
8+
bb = AA * uu #@. -(pi^2)*uu
9+
10+
id = Matrix(I,n,n)
11+
R = id[2:end-1,:]
12+
13+
x = R * xx
14+
A = R * AA * R' # full rank system
15+
u = R * uu
16+
b = A * u #R * bb
17+
18+
# test on some ODEProblem
19+
using OrdinaryDiffEq
20+
# add this problem to DiffEqProblemLibrary
21+
kx = 1
22+
kt = 1
23+
ut(x,t) = sin(kx*pi*x)*cos(kt*pi*t)
24+
ic(x) = ut(x,0.0)
25+
f(x,t) = ut(x,t)*(kx*pi)^2 - sin(kx*pi*x)*sin(kt*pi*t)*(kt*pi)
26+
u0 = ic.(x)
27+
dudt!(du,u,p,t) = -A*u + f.(x,t)
28+
dt = 0.01
29+
tspn = (0.0,1.0)
30+
func = ODEFunction(dudt!)
31+
prob = ODEProblem(func,u0,tspn)
32+
33+
34+
using OrdinaryDiffEq
35+
using DiffEqProblemLibrary.ODEProblemLibrary
36+
ODEProblemLibrary.importodeproblems()
37+
prob = ODEProblemLibrary.prob_ode_linear
38+
@show prob
39+
sol = solve(prob, Rodas5(linsolve=KrylovJL()); saveat=0.1)
40+
@show sol.retcode
41+

test/runtests.jl

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,17 @@ using Test
33

44
@testset "LinearSolve.jl" begin
55
using LinearAlgebra
6-
n = 32
7-
dx = 2/(n-1)
6+
n = 8
87

9-
xx = Array(range(start=-1,stop=1,length=n))
10-
AA = Tridiagonal(ones(n-1),-2ones(n),ones(n-1))/(dx*dx) # rank-deficient sys
11-
uu = @. sin(pi*xx)
12-
bb = AA * uu #@. -(pi^2)*uu
8+
A = Matrix(I,n,n)
9+
b = ones(n)
10+
A1 = A/1; b1 = rand(n); x1 = zero(b)
11+
A2 = A/2; b2 = rand(n); x2 = zero(b)
12+
A3 = A/3; b3 = rand(n); x3 = zero(b)
1313

14-
id = Matrix(I,n,n)
15-
R = id[2:end-1,:]
16-
17-
x = R * xx
18-
A = R * AA * R' # full rank system
19-
u = R * uu
20-
b = A * u #R * bb
21-
22-
# test on some ODEProblem
23-
# using OrdinaryDiffEq
24-
# # add this problem to DiffEqProblemLibrary
25-
# kx = 1
26-
# kt = 1
27-
# ut(x,t) = sin(kx*pi*x)*cos(kt*pi*t)
28-
# ic(x) = ut(x,0.0)
29-
# f(x,t) = ut(x,t)*(kx*pi)^2 - sin(kx*pi*x)*sin(kt*pi*t)*(kt*pi)
30-
# u0 = ic.(x)
31-
# dudt!(du,u,p,t) = -A*u + f.(x,t)
32-
# dt = 0.01
33-
# tspn = (0.0,1.0)
34-
# func = ODEFunction(dudt!)
35-
# prob = ODEProblem(func,u0,tspn)
36-
37-
x = zero(b)
38-
A1 = A; b1 = b
39-
A2 = 2A; b2 = 3b
40-
A3 = 3A; b3 = 2b
41-
42-
prob1 = LinearProblem(A1, b1; u0=x)
43-
prob2 = LinearProblem(A2, b2; u0=x)
44-
prob3 = LinearProblem(A3, b3; u0=x)
14+
prob1 = LinearProblem(A1, b1; u0=x1)
15+
prob2 = LinearProblem(A2, b2; u0=x2)
16+
prob3 = LinearProblem(A3, b3; u0=x3)
4517

4618
for alg in (
4719
:LUFactorization,
@@ -50,34 +22,33 @@ using Test
5022

5123
# :DefaultLinSolve,
5224

53-
:KrylovJL,
5425
# :KrylovJL,
26+
# :IterativeSolvers.jl
5527
# :KrylovKitJL,
5628

5729
)
5830
@eval begin
5931
y = solve($prob1, $alg())
60-
@test $A1 * y $b1
61-
@test $A1 * $x $b1
32+
@test $A1 * y $b1
33+
@test $A1 * $x1 $b1
6234

63-
y = $alg()($x, $A2, $b2)
64-
@test $A2 * y $b2
65-
@test $A2 * $x $b2
35+
y = $alg()($x2, $A2, $b2)
36+
@test $A2 * y $b2
37+
@test $A2 * $x2 $b2
6638

6739
cache = SciMLBase.init($prob1, $alg())
68-
y = cache($x, $A3, $b3)
69-
@test $A3 * $x $b3
70-
@test $A3 * y $b3
71-
end
72-
end
40+
y = cache($x3, $A1, $b1)
41+
@test $A1 * y $b1
42+
@test $A1 * $x3 $b1
7343

44+
y = cache($x3, $A1, $b2)
45+
@test $A1 * y $b2
46+
@test $A1 * $x3 $b2
7447

75-
# using OrdinaryDiffEq
76-
# using DiffEqProblemLibrary.ODEProblemLibrary
77-
# ODEProblemLibrary.importodeproblems()
78-
# prob = ODEProblemLibrary.prob_ode_linear
79-
# @show prob
80-
# sol = solve(prob, Rodas5(linsolve=KrylovJL()); saveat=0.1)
81-
# @show sol.retcode
48+
y = cache($x3, $A2, $b3)
49+
@test $A2 * y $b3
50+
@test $A2 * $x3 $b3
51+
end
52+
end
8253

8354
end

0 commit comments

Comments
 (0)