Skip to content

Commit cdfe484

Browse files
Merge pull request #21 from vpuri3/vp/updates
LinearSolve interface
2 parents eb61248 + c273237 commit cdfe484

File tree

10 files changed

+574
-103
lines changed

10 files changed

+574
-103
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
*.jl.mem
44
/docs/build/
55
Manifest.toml
6+
7+
*.swp

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,20 @@ version = "0.1.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
89
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
10+
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
911
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
1013
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1114
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1215
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
16+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1317
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
1418

1519
[compat]
1620
ArrayInterface = "3"
21+
IterativeSolvers = "0.9.2"
1722
Krylov = "0.7"
1823
Reexport = "1"
1924
SciMLBase = "1.18.6"
@@ -22,7 +27,9 @@ UnPack = "1"
2227
julia = "1"
2328

2429
[extras]
30+
DiffEqProblemLibrary = "a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9"
31+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2532
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2633

2734
[targets]
28-
test = ["Test"]
35+
test = ["Test", "OrdinaryDiffEq", "DiffEqProblemLibrary"]

src/LinearSolve.jl

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,48 @@
11
module LinearSolve
22

3-
using ArrayInterface: lu_instance
3+
using ArrayInterface
4+
using RecursiveFactorization
45
using Base: cache_dependencies, Bool
5-
using Krylov
66
using LinearAlgebra
7-
using Reexport
7+
using SparseArrays
88
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
99
using Setfield
1010
using UnPack
1111

12+
# wrap
13+
import Krylov
14+
import KrylovKit
15+
import IterativeSolvers
16+
17+
using Reexport
1218
@reexport using SciMLBase
1319

14-
abstract type SciMLLinearSolveAlgorithm end
20+
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
21+
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
22+
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
1523

1624
include("common.jl")
1725
include("factorization.jl")
18-
include("krylov.jl")
26+
include("wrappers.jl")
27+
include("default.jl")
28+
29+
const IS_OPENBLAS = Ref(true)
30+
isopenblas() = IS_OPENBLAS[]
31+
32+
function __init__()
33+
@static if VERSION < v"1.7beta"
34+
blas = BLAS.vendor()
35+
IS_OPENBLAS[] = blas == :openblas64 || blas == :openblas
36+
else
37+
IS_OPENBLAS[] = occursin("openblas", BLAS.get_config().loaded_libs[1].libname)
38+
end
39+
end
1940

20-
export LUFactorization, SVDFactorization, QRFactorization
21-
export KrylovJL
41+
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization
42+
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB,
43+
KrylovJL_MINRES,
44+
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
45+
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
46+
export DefaultLinSolve
2247

2348
end

src/common.jl

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,90 @@
1-
struct LinearCache{TA,Tb,Tp,Talg,Tc,Tr,Tl}
1+
struct LinearCache{TA,Tb,Tu,Tp,Talg,Tc,Tl,Tr}
22
A::TA
33
b::Tb
4+
u::Tu
45
p::Tp
56
alg::Talg
6-
cacheval::Tc
7-
isfresh::Bool
8-
Pr::Tr
9-
Pl::Tl
7+
cacheval::Tc # store alg cache here
8+
isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A
9+
Pl::Tl # store final preconditioner here. not being used rn
10+
Pr::Tr # wrappers are using preconditioner in cache.alg for now
1011
end
1112

12-
function set_A(cache, A)
13+
function set_A(cache::LinearCache, A)
1314
@set! cache.A = A
1415
@set! cache.isfresh = true
16+
return cache
1517
end
1618

17-
function set_b(cache, b)
19+
function set_b(cache::LinearCache, b)
1820
@set! cache.b = b
21+
return cache
22+
end
23+
24+
function set_u(cache::LinearCache, u)
25+
@set! cache.u = u
26+
return cache
1927
end
2028

21-
function set_p(cache, p)
29+
function set_p(cache::LinearCache, p)
2230
@set! cache.p = p
23-
# @set! cache.isfresh = true
31+
# @set! cache.isfresh = true
32+
return cache
2433
end
2534

26-
function set_cacheval(cache::LinearCache, alg)
35+
function set_cacheval(cache::LinearCache, alg_cache)
2736
if cache.isfresh
28-
@set! cache.cacheval = alg
37+
@set! cache.cacheval = alg_cache
2938
@set! cache.isfresh = false
3039
end
3140
return cache
3241
end
3342

34-
function SciMLBase.init(
35-
prob::LinearProblem,
36-
alg,
37-
args...;
38-
alias_A = false,
39-
alias_b = false,
40-
kwargs...,
41-
)
42-
@unpack A, b, p = prob
43-
if alg isa LUFactorization
44-
fact = lu_instance(A)
45-
Tfact = typeof(fact)
46-
else
47-
fact = nothing
48-
Tfact = Any
49-
end
50-
Pr = nothing
51-
Pl = nothing
43+
init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
44+
45+
function SciMLBase.init(prob::LinearProblem, alg, args...;
46+
alias_A = false, alias_b = false,
47+
kwargs...,
48+
)
49+
@unpack A, b, u0, p = prob
5250

53-
A = alias_A ? A : copy(A)
54-
b = alias_b ? b : copy(b)
51+
u0 = (u0 === nothing) ? zero(b) : u0
52+
53+
cacheval = init_cacheval(alg, A, b, u0)
54+
isfresh = cacheval === nothing
55+
Tc = isfresh ? Any : typeof(cacheval)
56+
57+
Pl = LinearAlgebra.I
58+
Pr = LinearAlgebra.I
59+
60+
A = alias_A ? A : deepcopy(A)
61+
b = alias_b ? b : deepcopy(b)
5562

5663
cache = LinearCache{
5764
typeof(A),
5865
typeof(b),
66+
typeof(u0),
5967
typeof(p),
6068
typeof(alg),
61-
Tfact,
62-
typeof(Pr),
69+
Tc,
6370
typeof(Pl),
71+
typeof(Pr),
6472
}(
6573
A,
6674
b,
75+
u0,
6776
p,
6877
alg,
69-
fact,
70-
true,
71-
Pr,
78+
cacheval,
79+
isfresh,
7280
Pl,
81+
Pr,
7382
)
7483
return cache
7584
end
7685

77-
SciMLBase.solve(prob::LinearProblem, alg, args...; kwargs...) =
78-
solve(init(prob, alg, args...; kwargs...))
86+
SciMLBase.solve(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
87+
args...; kwargs...) = solve(init(prob, alg, args...; kwargs...))
7988

80-
SciMLBase.solve(cache) = solve(cache, cache.alg)
89+
SciMLBase.solve(cache::LinearCache, args...; kwargs...) =
90+
solve(cache, cache.alg, args...; kwargs...)

src/default.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
## Default algorithm
2+
3+
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
4+
args...; kwargs...)
5+
@unpack A = cache
6+
if A isa Matrix
7+
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 ||
8+
(isopenblas() && size(A,1) <= 500)
9+
)
10+
alg = GenericFactorization(;fact_alg=:(RecursiveFactorization.lu!))
11+
SciMLBase.solve(cache, alg, args...; kwargs...)
12+
else
13+
alg = LUFactorization()
14+
SciMLBase.solve(cache, alg, args...; kwargs...)
15+
end
16+
elseif A isa Tridiagonal
17+
alg = GenericFactorization(;fact_alg=lu!)
18+
SciMLBase.solve(cache, alg, args...; kwargs...)
19+
elseif A isa SymTridiagonal
20+
alg = GenericFactorization(;fact_alg=ldlt!)
21+
SciMLBase.solve(cache, alg, args...; kwargs...)
22+
elseif A isa SparseMatrixCSC
23+
alg = LUFactorization()
24+
SciMLBase.solve(cache, alg, args...; kwargs...)
25+
elseif ArrayInterface.isstructured(A)
26+
alg = GenericFactorization()
27+
SciMLBase.solve(cache, alg, args...; kwargs...)
28+
elseif !(A isa AbstractDiffEqOperator)
29+
alg = QRFactorization()
30+
SciMLBase.solve(cache, alg, args...; kwargs...)
31+
else
32+
alg = IterativeSolversJL_GMRES()
33+
SciMLBase.solve(cache, alg, args...; kwargs...)
34+
end
35+
end

src/factorization.jl

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1+
function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization)
2+
if cache.isfresh
3+
fact = init_cacheval(alg, cache.A, cache.b, cache.u)
4+
cache = set_cacheval(cache, fact)
5+
end
6+
7+
ldiv!(cache.u,cache.cacheval, cache.b)
8+
end
9+
10+
## LUFactorization
111

2-
struct LUFactorization{P} <: AbstractLinearAlgorithm
12+
struct LUFactorization{P} <: AbstractFactorization
313
pivot::P
414
end
515

@@ -12,14 +22,16 @@ function LUFactorization()
1222
LUFactorization(pivot)
1323
end
1424

15-
function SciMLBase.solve(cache::LinearCache, alg::LUFactorization)
16-
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
17-
error("LU is not defined for $(typeof(prob.A))")
18-
cache = set_cacheval(cache, lu!(cache.A, alg.pivot))
19-
ldiv!(cache.cacheval, cache.b)
25+
function init_cacheval(alg::LUFactorization, A, b, u)
26+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
27+
error("LU is not defined for $(typeof(A))")
28+
fact = lu!(A, alg.pivot)
29+
return fact
2030
end
2131

22-
struct QRFactorization{P} <: AbstractLinearAlgorithm
32+
## QRFactorization
33+
34+
struct QRFactorization{P} <: AbstractFactorization
2335
pivot::P
2436
blocksize::Int
2537
end
@@ -33,26 +45,44 @@ function QRFactorization()
3345
QRFactorization(pivot, 16)
3446
end
3547

36-
function SciMLBase.solve(cache::LinearCache, alg::QRFactorization)
37-
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
38-
error("QR is not defined for $(typeof(prob.A))")
39-
cache = set_cacheval(
40-
cache,
41-
qr!(cache.A.A, alg.pivot; blocksize = alg.blocksize),
42-
)
43-
ldiv!(cache.cacheval, cache.b)
48+
function init_cacheval(alg::QRFactorization, A, b, u)
49+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
50+
error("QR is not defined for $(typeof(A))")
51+
52+
fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
53+
return fact
4454
end
4555

46-
struct SVDFactorization{A} <: AbstractLinearAlgorithm
56+
## SVDFactorization
57+
58+
struct SVDFactorization{A} <: AbstractFactorization
4759
full::Bool
4860
alg::A
4961
end
5062

5163
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
5264

53-
function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization)
54-
cache.A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
55-
error("SVD is not defined for $(typeof(prob.A))")
56-
cache = set_cacheval(cache, svd!(cache.A; full = alg.full, alg = alg.alg))
57-
ldiv!(cache.cacheval, cache.b)
65+
function init_cacheval(alg::SVDFactorization, A, b, u)
66+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
67+
error("SVD is not defined for $(typeof(A))")
68+
69+
fact = svd!(A; full = alg.full, alg = alg.alg)
70+
return fact
71+
end
72+
73+
## GenericFactorization
74+
75+
struct GenericFactorization{F} <: AbstractFactorization
76+
fact_alg::F
77+
end
78+
79+
GenericFactorization(;fact_alg = LinearAlgebra.factorize) =
80+
GenericFactorization(fact_alg)
81+
82+
function init_cacheval(alg::GenericFactorization, A, b, u)
83+
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
84+
error("GenericFactorization is not defined for $(typeof(A))")
85+
86+
fact = alg.fact_alg(A)
87+
return fact
5888
end

src/krylov.jl

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)