Skip to content

Commit 580221f

Browse files
fix up the interface
1 parent 0174a5e commit 580221f

File tree

2 files changed

+76
-134
lines changed

2 files changed

+76
-134
lines changed

src/common.jl

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
44
u::Tu
55
p::Tp
66
alg::Talg
7-
cacheval::Tc # store alg cache here
7+
cacheval::Tc # store alg cache here
88
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
99
Pl::Tl # store final preconditioner here. not being used rn
1010
Pr::Tr # wrappers are using preconditioner in cache.alg for now
@@ -88,38 +88,3 @@ SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
8888

8989
SciMLBase.solve(cache::LinearCache, args...; kwargs...) =
9090
solve(cache, cache.alg, args...; kwargs...)
91-
92-
## make alg callable
93-
94-
function (alg::SciMLLinearSolveAlgorithm)(prob::LinearProblem,args...; kwargs...)
95-
x = solve(prob, alg, args...; kwargs...)
96-
return x
97-
end
98-
99-
function (alg::SciMLLinearSolveAlgorithm)(x,A,b,args...;u0=nothing,kwargs...)
100-
prob = LinearProblem(A, b; u0=x)
101-
x = alg(prob, args...; kwargs...)
102-
return x
103-
end
104-
105-
## make cache callable - and reuse
106-
107-
function (cache::LinearCache)(prob::LinearProblem, args...; kwargs...)
108-
109-
if(prob.A != cache.A) cache = set_A(cache, prob.A) end
110-
if(prob.b != cache.b) cache = set_b(cache, prob.b) end
111-
112-
if(prob.u0 == nothing)
113-
prob.u0 = zero(x)
114-
end
115-
116-
cache = set_u(cache, prob.u0)
117-
x = solve(cache, args...; kwargs...)
118-
return x
119-
end
120-
121-
function (cache::LinearCache)(x, A, b, args...; kwargs...)
122-
prob = LinearProblem(A, b; u0=x)
123-
x = cache(prob, args...; kwargs...)
124-
return x
125-
end

test/runtests.jl

Lines changed: 75 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,97 @@
1-
using LinearSolve
1+
using LinearSolve, LinearAlgebra
22
using Test
33

4-
@testset "LinearSolve.jl" begin
5-
using LinearAlgebra
6-
n = 8
4+
n = 8
75

8-
A = Matrix(I,n,n)
9-
b = ones(n)
10-
A1 = A/1; b1 = rand(n); x1 = zero(b)
11-
A2 = A/2; b2 = rand(n); x2 = zero(b)
12-
A3 = A/3; b3 = rand(n); x3 = zero(b)
6+
A = Matrix(I,n,n)
7+
b = ones(n)
8+
A1 = A/1; b1 = rand(n); x1 = zero(b)
9+
A2 = A/2; b2 = rand(n); x2 = zero(b)
1310

14-
prob1 = LinearProblem(A1, b1; u0=x1)
15-
prob2 = LinearProblem(A2, b2; u0=x2)
16-
prob3 = LinearProblem(A3, b3; u0=x3)
11+
prob1 = LinearProblem(A1, b1; u0=x1)
12+
prob2 = LinearProblem(A2, b2; u0=x2)
1713

18-
function test_interface(alg, kwargs, prob1, prob2, prob3)
19-
A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
20-
A2 = prob2.A; b2 = prob2.b; x2 = prob2.u0
21-
A3 = prob3.A; b3 = prob3.b; x3 = prob3.u0
14+
function test_interface(alg, prob1, prob2, prob3)
15+
A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
16+
A2 = prob2.A; b2 = prob2.b; x2 = prob2.u0
2217

23-
@eval begin
24-
y = solve($prob1, $alg(;$kwargs...))
25-
@test $A1 * y $b1 # out of place
26-
@test $A1 * $x1 $b1 # in place
18+
y = solve(prob1, alg)
19+
@test A1 * y b1
2720

28-
y = $alg(;$kwargs...)($x2, $A2, $b2) # alg is callable
29-
@test $A2 * y $b2
30-
@test $A2 * $x2 $b2
21+
cache = SciMLBase.init(prob1,alg) # initialize cache
22+
y = solve(cache)
23+
@test A1 * y b1
3124

32-
cache = SciMLBase.init($prob1,
33-
$alg(;$kwargs...)) # initialize cache
34-
y = cache($x3, $A1, $b1) # cache is callable
35-
@test $A1 * y $b1
36-
@test $A1 * $x3 $b1
25+
cache = LinearSolve.set_A(cache,A2)
26+
y = solve(cache)
27+
@test A2 * y b1
3728

38-
y = cache($x3, $A1, $b2) # reuse factorization
39-
@test $A1 * y $b2 # with different RHS
40-
@test $A1 * $x3 $b2
29+
cache = LinearSolve.set_b(cache,b2)
30+
y = solve(cache)
31+
@test A2 * y b2
4132

42-
y = cache($x3, $A2, $b3) # new factorization
43-
@test $A2 * y $b3 # same old cache
44-
@test $A2 * $x3 $b3
45-
end
46-
47-
x1 .= 0.0
48-
x2 .= 0.0
49-
x3 .= 0.0
50-
51-
return
52-
end
53-
54-
@testset "factorization" begin
55-
kwargs = :()
56-
for alg in (
57-
:LUFactorization,
58-
:QRFactorization,
59-
:SVDFactorization,
60-
# :DefaultLinSolve
61-
)
62-
@testset "$alg" begin
63-
test_interface(alg, kwargs, prob1, prob2, prob3)
64-
end
65-
end
33+
return
34+
end
6635

67-
alg = :DefaultFactorization
36+
@testset "Concrete Factorizations" begin
37+
for alg in (
38+
LUFactorization(),
39+
QRFactorization(),
40+
SVDFactorization(),
41+
# DefaultLinSolve()
42+
)
6843
@testset "$alg" begin
69-
for fact_alg in (
70-
:lu, :lu!,
71-
:qr, :qr!,
72-
:cholesky, :cholesky!,
73-
# :ldlt, :ldlt!,
74-
:bunchkaufman, :bunchkaufman!,
75-
:lq, :lq!,
76-
:svd, :svd!,
77-
:(LinearAlgebra.factorize),
78-
)
79-
@testset "fact_alg = $fact_alg" begin
80-
kwargs = :(fact_alg=$fact_alg,)
81-
test_interface(alg, kwargs, prob1, prob2, prob3)
82-
end
83-
end
44+
test_interface(alg, prob1, prob2, prob3)
8445
end
85-
8646
end
47+
end
8748

88-
@testset "KrylovJL" begin
89-
kwargs = :(ifverbose=false, abstol=1e-8, reltol=1e-8, maxiter=30,
90-
gmres_restart=5)
91-
for alg in (
92-
:KrylovJL,
93-
:KrylovJL_CG,
94-
:KrylovJL_GMRES,
95-
# :KrylovJL_BICGSTAB,
96-
:KrylovJL_MINRES,
97-
)
98-
@testset "$alg" begin
99-
test_interface(alg, kwargs, prob1, prob2, prob3)
100-
end
49+
@testset "Generic Factorizations" begin
50+
for fact_alg in (
51+
lu, lu!,
52+
qr, qr!,
53+
cholesky, cholesky!,
54+
# ldlt, ldlt!,
55+
bunchkaufman, bunchkaufman!,
56+
lq, lq!,
57+
svd, svd!,
58+
LinearAlgebra.factorize,
59+
)
60+
@testset "fact_alg = $fact_alg" begin
61+
alg = DefaultFactorization(fact_alg=fact_alg)
62+
test_interface(alg, prob1, prob2, prob3)
10163
end
10264
end
65+
end
10366

104-
@testset "IterativeSolversJL" begin
105-
kwargs = :(ifverbose=false, abstol=1e-8, reltol=1e-8, maxiter=30,
106-
gmres_restart=5)
107-
for alg in (
108-
:IterativeSolversJL,
109-
:IterativeSolversJL_CG,
110-
:IterativeSolversJL_GMRES,
111-
# :IterativeSolversJL_BICGSTAB,
112-
:IterativeSolversJL_MINRES,
113-
)
114-
@testset "$alg" begin
115-
test_interface(alg, kwargs, prob1, prob2, prob3)
116-
end
67+
@testset "KrylovJL" begin
68+
kwargs = (;ifverbose=false, abstol=1e-8, reltol=1e-8, maxiter=30,
69+
gmres_restart=5)
70+
for alg in (
71+
("Default",KrylovJL(kwargs...)),
72+
("CG",KrylovJL_CG(kwargs...)),
73+
("GMRES",KrylovJL_GMRES(kwargs...)),
74+
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
75+
("MINRES",KrylovJL_MINRES(kwargs...)),
76+
)
77+
@testset "$(alg[1])" begin
78+
test_interface(alg[2], prob1, prob2, prob3)
11779
end
11880
end
81+
end
11982

83+
@testset "IterativeSolversJL" begin
84+
kwargs = (;ifverbose=false, abstol=1e-8, reltol=1e-8, maxiter=30,
85+
gmres_restart=5)
86+
for alg in (
87+
("Default", IterativeSolversJL(kwargs...)),
88+
("CG", IterativeSolversJL_CG(kwargs...)),
89+
("GMRES",IterativeSolversJL_GMRES(kwargs...)),
90+
# ("BICGSTAB",IterativeSolversJL_BICGSTAB(kwargs...)),
91+
("MINRES",IterativeSolversJL_MINRES(kwargs...)),
92+
)
93+
@testset "$(alg[1])" begin
94+
test_interface(alg[2], prob1, prob2, prob3)
95+
end
96+
end
12097
end

0 commit comments

Comments
 (0)