Skip to content

Commit 0dc8240

Browse files
get interface cleaned and ready to merge
1 parent 580221f commit 0dc8240

File tree

5 files changed

+68
-69
lines changed

5 files changed

+68
-69
lines changed

src/LinearSolve.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,21 @@ include("factorization.jl")
2626
include("wrappers.jl")
2727
include("default.jl")
2828

29-
export LUFactorization, SVDFactorization, QRFactorization, DefaultFactorization
29+
const IS_OPENBLAS = Ref(true)
30+
isopenblas() = IS_OPENBLAS[]
31+
32+
function __init__()
33+
@static if VERSION < v"1.7beta"
34+
blas = BLAS.vendor()
35+
IS_OPENBLAS[] = blas == :openblas64 || blas == :openblas
36+
else
37+
IS_OPENBLAS[] = occursin("openblas", BLAS.get_config().loaded_libs[1].libname)
38+
end
39+
end
40+
41+
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization
3042
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB,
31-
KrylovJL_MINRES,
43+
KrylovJL_MINRES,
3244
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
3345
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
3446
export DefaultLinSolve

src/default.jl

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,35 @@
11
## Default algorithm
22

3-
struct DefaultLinSolve{Ta} <: SciMLLinearSolveAlgorithm
4-
linalg::Ta
5-
ifopenblas::Union{Bool,Nothing}
6-
isset::Bool # true => do nothing, false => find alg
7-
end
8-
9-
DefaultLinSolve() = DefaultLinSolve(nothing, nothing, true)
10-
11-
function isopenblas()
12-
@static if VERSION < v"1.7beta"
13-
blas = BLAS.vendor()
14-
blas == :openblas64 || blas == :openblas
15-
else
16-
occursin("openblas", BLAS.get_config().loaded_libs[1].libname)
17-
end
18-
end
19-
20-
function SciMLBase.solve(cache::LinearCache, alg::DefaultLinSolve,
3+
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
214
args...; kwargs...)
225
@unpack A = cache
23-
24-
if alg.isset
25-
linalg =
26-
if A isa Matrix
27-
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
28-
(p.openblas && size(A,1) <= 500)
29-
)
30-
DefaultFactorization(;fact_alg=:(RecursiveFactorization.lu!))
31-
else
32-
LUFactorization()
33-
end
34-
elseif A isa Tridiagonal
35-
DefaultFactorization(;fact_alg=lu!)
36-
elseif A isa SymTridiagonal
37-
DefaultFactorization(;fact_alg=ldlt!)
38-
elseif A isa SparseMatrixCSC
39-
LUFactorization()
40-
elseif ArrayInterface.isstructured(A)
41-
DefaultFactorization()
42-
elseif !(A isa AbstractDiffEqOperator)
43-
QRFactorization()
44-
else
45-
IterativeSolversJL_GMRES()
46-
end
47-
48-
@set! alg.linalg = linalg
6+
if A isa Matrix
7+
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
8+
(isopenblas() && size(A,1) <= 500)
9+
)
10+
alg = GenericFactorization(;fact_alg=:(RecursiveFactorization.lu!))
11+
SciMLBase.solve(cache, alg, args...; kwargs...)
12+
else
13+
alg = LUFactorization()
14+
SciMLBase.solve(cache, alg, args...; kwargs...)
15+
end
16+
elseif A isa Tridiagonal
17+
alg = GenericFactorization(;fact_alg=lu!)
18+
SciMLBase.solve(cache, alg, args...; kwargs...)
19+
elseif A isa SymTridiagonal
20+
alg = GenericFactorization(;fact_alg=ldlt!)
21+
SciMLBase.solve(cache, alg, args...; kwargs...)
22+
elseif A isa SparseMatrixCSC
23+
alg = LUFactorization()
24+
SciMLBase.solve(cache, alg, args...; kwargs...)
25+
elseif ArrayInterface.isstructured(A)
26+
alg = GenericFactorization()
27+
SciMLBase.solve(cache, alg, args...; kwargs...)
28+
elseif !(A isa AbstractDiffEqOperator)
29+
alg = QRFactorization()
30+
SciMLBase.solve(cache, alg, args...; kwargs...)
31+
else
32+
alg = IterativeSolversJL_GMRES()
33+
SciMLBase.solve(cache, alg, args...; kwargs...)
4934
end
50-
51-
SciMLBase.solve(cache, alg.linalg, args...; kwargs...)
5235
end

src/factorization.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization)
32
if cache.isfresh
43
fact = init_cacheval(alg, cache.A, cache.b, cache.u)
@@ -71,18 +70,18 @@ function init_cacheval(alg::SVDFactorization, A, b, u)
7170
return fact
7271
end
7372

74-
## DefaultFactorization
73+
## GenericFactorization
7574

76-
struct DefaultFactorization{F} <: AbstractFactorization
75+
struct GenericFactorization{F} <: AbstractFactorization
7776
fact_alg::F
7877
end
7978

80-
DefaultFactorization(;fact_alg = LinearAlgebra.factorize) =
81-
DefaultFactorization(fact_alg)
79+
GenericFactorization(;fact_alg = LinearAlgebra.factorize) =
80+
GenericFactorization(fact_alg)
8281

83-
function init_cacheval(alg::DefaultFactorization, A, b, u)
82+
function init_cacheval(alg::GenericFactorization, A, b, u)
8483
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
85-
error("DefaultFactorization is not defined for $(typeof(A))")
84+
error("GenericFactorization is not defined for $(typeof(A))")
8685

8786
fact = alg.fact_alg(A)
8887
return fact

src/wrappers.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
#TODO: composed preconditioners, preconditioner setter for cache,
2+
#TODO: composed preconditioners, preconditioner setter for cache,
33
# detailed tests for wrappers
44

55
## Preconditioners
@@ -65,7 +65,7 @@ KrylovJL_MINRES(args...;kwargs...) =
6565
KrylovJL(args...; KrylovAlg=Krylov.minres!, kwargs...)
6666

6767
function get_KrylovJL_solver(KrylovAlg)
68-
KS =
68+
KS =
6969
if (KrylovAlg === Krylov.lsmr! ) Krylov.LsmrSolver
7070
elseif (KrylovAlg === Krylov.cgs! ) Krylov.CgsSolver
7171
elseif (KrylovAlg === Krylov.usymlq! ) Krylov.UsymlqSolver
@@ -260,4 +260,3 @@ function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
260260

261261
return cache.u
262262
end
263-

test/runtests.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using LinearSolve, LinearAlgebra
22
using Test
33

44
n = 8
5-
65
A = Matrix(I,n,n)
76
b = ones(n)
87
A1 = A/1; b1 = rand(n); x1 = zero(b)
@@ -11,7 +10,7 @@ A2 = A/2; b2 = rand(n); x2 = zero(b)
1110
prob1 = LinearProblem(A1, b1; u0=x1)
1211
prob2 = LinearProblem(A2, b2; u0=x2)
1312

14-
function test_interface(alg, prob1, prob2, prob3)
13+
function test_interface(alg, prob1, prob2)
1514
A1 = prob1.A; b1 = prob1.b; x1 = prob1.u0
1615
A2 = prob2.A; b2 = prob2.b; x2 = prob2.u0
1716

@@ -22,26 +21,32 @@ function test_interface(alg, prob1, prob2, prob3)
2221
y = solve(cache)
2322
@test A1 * y b1
2423

25-
cache = LinearSolve.set_A(cache,A2)
24+
cache = LinearSolve.set_A(cache,copy(A2))
2625
y = solve(cache)
2726
@test A2 * y b1
2827

28+
@show A2, b2
29+
2930
cache = LinearSolve.set_b(cache,b2)
3031
y = solve(cache)
32+
@show cache.A, cache.b, y
3133
@test A2 * y b2
3234

3335
return
3436
end
3537

38+
alg = GenericFactorization(fact_alg=cholesky!)
39+
test_interface(alg, prob1, prob2)
40+
3641
@testset "Concrete Factorizations" begin
3742
for alg in (
3843
LUFactorization(),
3944
QRFactorization(),
4045
SVDFactorization(),
41-
# DefaultLinSolve()
46+
#nothing
4247
)
4348
@testset "$alg" begin
44-
test_interface(alg, prob1, prob2, prob3)
49+
test_interface(alg, prob1, prob2)
4550
end
4651
end
4752
end
@@ -50,16 +55,17 @@ end
5055
for fact_alg in (
5156
lu, lu!,
5257
qr, qr!,
53-
cholesky, cholesky!,
58+
cholesky,
59+
#cholesky!,
5460
# ldlt, ldlt!,
5561
bunchkaufman, bunchkaufman!,
5662
lq, lq!,
5763
svd, svd!,
5864
LinearAlgebra.factorize,
5965
)
6066
@testset "fact_alg = $fact_alg" begin
61-
alg = DefaultFactorization(fact_alg=fact_alg)
62-
test_interface(alg, prob1, prob2, prob3)
67+
alg = GenericFactorization(fact_alg=fact_alg)
68+
test_interface(alg, prob1, prob2)
6369
end
6470
end
6571
end
@@ -75,7 +81,7 @@ end
7581
("MINRES",KrylovJL_MINRES(kwargs...)),
7682
)
7783
@testset "$(alg[1])" begin
78-
test_interface(alg[2], prob1, prob2, prob3)
84+
test_interface(alg[2], prob1, prob2)
7985
end
8086
end
8187
end
@@ -88,10 +94,10 @@ end
8894
("CG", IterativeSolversJL_CG(kwargs...)),
8995
("GMRES",IterativeSolversJL_GMRES(kwargs...)),
9096
# ("BICGSTAB",IterativeSolversJL_BICGSTAB(kwargs...)),
91-
("MINRES",IterativeSolversJL_MINRES(kwargs...)),
97+
# ("MINRES",IterativeSolversJL_MINRES(kwargs...)),
9298
)
9399
@testset "$(alg[1])" begin
94-
test_interface(alg[2], prob1, prob2, prob3)
100+
test_interface(alg[2], prob1, prob2)
95101
end
96102
end
97103
end

0 commit comments

Comments
 (0)