Skip to content

Commit 9d1def4

Browse files
committed
abstractfactorization
1 parent 5ac02ac commit 9d1def4

File tree

2 files changed

+17
-32
lines changed

2 files changed

+17
-32
lines changed

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using Reexport
1616
@reexport using SciMLBase
1717

1818
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
19+
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
1920

2021
include("common.jl")
2122
include("default.jl")

src/factorization.jl

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11

2+
3+
function SciMLBase.solve(cache::LinearCache, alg::AbstractFactorization)
4+
if cache.isfresh
5+
fact = init_cacheval(alg, cache.A, cache.b, cache.u)
6+
cache = set_cacheval(cache, fact)
7+
end
8+
9+
ldiv!(cache.u,cache.cacheval, cache.b)
10+
end
11+
212
## LUFactorization
313

4-
struct LUFactorization{P} <: SciMLLinearSolveAlgorithm
14+
struct LUFactorization{P} <: AbstractFactorization
515
pivot::P
616
end
717

@@ -14,25 +24,16 @@ function LUFactorization()
1424
LUFactorization(pivot)
1525
end
1626

17-
function init_cacheval(A, alg::LUFactorization)
27+
function init_cacheval(alg::LUFactorization, A, b, u)
1828
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
1929
error("LU is not defined for $(typeof(A))")
2030
fact = lu!(A, alg.pivot)
2131
return fact
2232
end
2333

24-
function SciMLBase.solve(cache::LinearCache, alg::LUFactorization)
25-
if cache.isfresh
26-
fact = init_cacheval(cache.A, alg)
27-
cache = set_cacheval(cache, fact)
28-
end
29-
30-
ldiv!(cache.u,cache.cacheval, cache.b)
31-
end
32-
3334
## QRFactorization
3435

35-
struct QRFactorization{P} <: SciMLLinearSolveAlgorithm
36+
struct QRFactorization{P} <: AbstractFactorization
3637
pivot::P
3738
blocksize::Int
3839
end
@@ -46,45 +47,28 @@ function QRFactorization()
4647
QRFactorization(pivot, 16)
4748
end
4849

49-
function init_cacheval(A, alg::QRFactorization)
50+
function init_cacheval(alg::QRFactorization, A, b, u)
5051
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
5152
error("QR is not defined for $(typeof(A))")
5253

5354
fact = qr!(A.A, alg.pivot; blocksize = alg.blocksize)
5455
return fact
5556
end
5657

57-
function SciMLBase.solve(cache::LinearCache, alg::QRFactorization)
58-
if cache.isfresh
59-
fact = init_cacheval(cache.A, alg)
60-
cache = set_cacheval(cache, fact)
61-
end
62-
63-
ldiv!(cache.u,cache.cacheval, cache.b)
64-
end
65-
6658
## SVDFactorization
6759

68-
struct SVDFactorization{A} <: SciMLLinearSolveAlgorithm
60+
struct SVDFactorization{A} <: AbstractFactorization
6961
full::Bool
7062
alg::A
7163
end
7264

7365
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())
7466

75-
function init_cacheval(A, alg::SVDFactorization)
67+
function init_cacheval(alg::SVDFactorization, A, b, u)
7668
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
7769
error("SVD is not defined for $(typeof(A))")
7870

7971
fact = svd!(A; full = alg.full, alg = alg.alg)
8072
return fact
8173
end
8274

83-
function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization)
84-
if cache.isfresh
85-
fact = init_cacheval(cache.A, alg)
86-
cache = set_cacheval(cache, fact)
87-
end
88-
89-
ldiv!(cache.u,cache.cacheval, cache.b)
90-
end

0 commit comments

Comments
 (0)