Skip to content

Commit e3f05ea

Browse files
committed
added new interfaces. in-place tests failing
1 parent 8f5f5b9 commit e3f05ea

File tree

2 files changed

+82
-50
lines changed

2 files changed

+82
-50
lines changed

src/common.jl

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,54 @@
1-
struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tr,Tl}
1+
struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
22
A::TA
33
b::Tb
44
u::Tu
55
p::Tp
66
alg::Talg
7-
cacheval::Tc
7+
cacheval::Tc # store alg cache here
88
isfresh::Bool
9-
Pr::Tr
109
Pl::Tl
11-
# k::Tk # iteration count
10+
Pr::Tr
1211
end
1312

1413
function set_A(cache, A) # and ! to function name
1514
@set! cache.A = A
1615
@set! cache.isfresh = true
16+
return cache
1717
end
1818

1919
function set_b(cache, b)
2020
@set! cache.b = b
21+
return cache
2122
end
2223

2324
function set_u(cache, u)
2425
@set! cache.u = u
26+
return cache
2527
end
2628

2729
function set_p(cache, p)
2830
@set! cache.p = p
2931
# @set! cache.isfresh = true
32+
return cache
3033
end
3134

32-
function set_cacheval(cache::LinearCache, alg)
35+
function set_cacheval(cache, alg_cache)
3336
if cache.isfresh
34-
@set! cache.cacheval = alg
37+
@set! cache.cacheval = alg_cache
3538
@set! cache.isfresh = false
3639
end
3740
return cache
3841
end
3942

40-
function SciMLBase.init(
41-
prob::LinearProblem,
42-
alg,
43-
args...;
44-
alias_A = false,
45-
alias_b = false,
46-
kwargs...,
47-
)
43+
#function init_cacheval(cacheval, alg::SciMLLinearSolveAlgorithm)
44+
#
45+
# return
46+
#end
47+
48+
function SciMLBase.init(prob::LinearProblem, alg, args...;
49+
alias_A = false, alias_b = false,
50+
kwargs...,
51+
)
4852
@unpack A, b, u0, p = prob
4953

5054
if u0 == nothing
@@ -58,8 +62,8 @@ function SciMLBase.init(
5862
fact = nothing
5963
Tfact = Any
6064
end
61-
Pr = LinearAlgebra.I
6265
Pl = LinearAlgebra.I
66+
Pr = LinearAlgebra.I
6367

6468
A = alias_A ? A : deepcopy(A)
6569
b = alias_b ? b : deepcopy(b)
@@ -71,8 +75,8 @@ function SciMLBase.init(
7175
typeof(p),
7276
typeof(alg),
7377
Tfact,
74-
typeof(Pr),
7578
typeof(Pl),
79+
typeof(Pr),
7680
}(
7781
A,
7882
b,
@@ -81,8 +85,8 @@ function SciMLBase.init(
8185
alg,
8286
fact,
8387
true,
84-
Pr,
8588
Pl,
89+
Pr,
8690
)
8791
return cache
8892
end
@@ -92,26 +96,35 @@ SciMLBase.solve(prob::LinearProblem, alg, args...; kwargs...) =
9296

9397
SciMLBase.solve(cache) = solve(cache, cache.alg)
9498

99+
function (alg::SciMLLinearSolveAlgorithm)(prob::LinearProblem,args...;
100+
u0=nothing,kwargs...)
101+
x = solve(prob, alg, args...; kwargs...)
102+
return x
103+
end
104+
95105
function (alg::SciMLLinearSolveAlgorithm)(x,A,b,args...;u0=nothing,kwargs...)
96106
prob = LinearProblem(A,b;u0=x)
97-
x = solve(prob,alg,args...;kwargs...)
107+
x = alg(prob, args...; kwargs...)
98108
return x
99109
end
100110

101-
# how to initialize cahce?
102-
103-
# use the same cache to solve multiple linear problems
104-
function (cache::LinearCache)(x,A,b,args...;u0=nothing,kwargs...)
105-
set_A(cache, A)
106-
set_b(cache, b)
111+
function (cache::LinearCache)(prob::LinearProblem,args...;u0=nothing,kwargs...)
107112

108-
if u0 == nothing
109-
x = zero(x)
110-
else
111-
x = u0
113+
if prob.u0 == nothing
114+
prob.u0 = zero(x)
112115
end
113-
set_u(cache, x)
114116

115-
x = solve(cache)
117+
cache = set_A(cache, prob.A)
118+
cache = set_b(cache, prob.b)
119+
cache = set_u(cache, prob.u0)
120+
121+
x = solve(cache,args...;kwargs...)
122+
return x
123+
end
124+
125+
function (cache::LinearCache)(x,A,b,args...;u0=nothing,kwargs...)
126+
127+
prob = LinearProblem(A,b;u0=x)
128+
x = cache(prob, args...; kwargs...)
116129
return x
117130
end

test/runtests.jl

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33

44
@testset "LinearSolve.jl" begin
55
using LinearAlgebra
6-
n = 100
6+
n = 32
77
dx = 2/(n-1)
88

99
xx = Array(range(start=-1,stop=1,length=n))
@@ -19,25 +19,6 @@ using Test
1919
u = R * uu
2020
b = A * u #R * bb
2121

22-
x = zero(b)
23-
prob = LinearProblem(A, b;u0=x)
24-
25-
# Factorization
26-
for alg in (:LUFactorization, :QRFactorization, :SVDFactorization,
27-
:KrylovJL,
28-
# :KrylovKitJL,
29-
)
30-
@eval begin
31-
@test $A * solve($prob, $alg();) $b
32-
$alg()($x, $A, $b)
33-
@test $A * $x $b
34-
35-
cache = SciMLBase.init($prob, $alg())
36-
cache($x, $A, $b)
37-
@test $A * $x $b
38-
end
39-
end
40-
4122
# test on some ODEProblem
4223
# using OrdinaryDiffEq
4324
# # add this problem to DiffEqProblemLibrary
@@ -53,6 +34,44 @@ using Test
5334
# func = ODEFunction(dudt!)
5435
# prob = ODEProblem(func,u0,tspn)
5536

37+
x = zero(b)
38+
A1 = A; b1 = b
39+
A2 = 2A; b2 = 3b
40+
A3 = 3A; b3 = 2b
41+
42+
prob1 = LinearProblem(A1, b1; u0=x)
43+
prob2 = LinearProblem(A2, b2; u0=x)
44+
prob3 = LinearProblem(A3, b3; u0=x)
45+
46+
for alg in (
47+
:LUFactorization,
48+
:QRFactorization,
49+
:SVDFactorization,
50+
51+
# :DefaultLinSolve,
52+
53+
:KrylovJL,
54+
# :KrylovJL,
55+
# :KrylovKitJL,
56+
57+
)
58+
@eval begin
59+
y = solve($prob1, $alg())
60+
@test $A1 * y $b1
61+
@test $A1 * $x $b1
62+
63+
y = $alg()($x, $A2, $b2)
64+
@test $A2 * y $b2
65+
@test $A2 * $x $b2
66+
67+
cache = SciMLBase.init($prob1, $alg())
68+
y = cache($x, $A3, $b3)
69+
@test $A3 * $x $b3
70+
@test $A3 * y $b3
71+
end
72+
end
73+
74+
5675
# using OrdinaryDiffEq
5776
# using DiffEqProblemLibrary.ODEProblemLibrary
5877
# ODEProblemLibrary.importodeproblems()

0 commit comments

Comments
 (0)