Skip to content

Commit 7ed66bd

Browse files
Merge pull request #99 from SciML/genericlu
add GenericLUFactorizations
2 parents 991a1da + b16a674 commit 7ed66bd

File tree

5 files changed

+190
-19
lines changed

5 files changed

+190
-19
lines changed

LICENSE

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2021 Jonathan <[email protected]> and contributors
3+
Copyright (c) 2021 SciML and contributors
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal
@@ -19,3 +19,6 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1919
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2020
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
SOFTWARE.
22+
23+
SimpleLU.jl is derived from https://github.com/JuliaGNI/SimpleSolvers.jl under
24+
an MIT license.

src/LinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
3636

3737
include("common.jl")
3838
include("factorization.jl")
39+
include("simplelu.jl")
3940
include("iterative_wrappers.jl")
4041
include("preconditioners.jl")
4142
include("default.jl")
@@ -45,7 +46,8 @@ const IS_OPENBLAS = Ref(true)
4546
isopenblas() = IS_OPENBLAS[]
4647

4748
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
48-
RFLUFactorization, UMFPACKFactorization, KLUFactorization
49+
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
50+
UMFPACKFactorization, KLUFactorization
4951
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5052
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
5153
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES

src/default.jl

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ function defaultalg(A,b)
1111
# whether MKL or OpenBLAS is being used
1212
if (A === nothing && !isgpu(b)) || A isa Matrix
1313
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
14-
ArrayInterface.can_setindex(b) && (length(b) <= 100 ||
15-
(isopenblas() && length(b) <= 500)
16-
)
17-
alg = RFLUFactorization()
14+
ArrayInterface.can_setindex(b)
15+
if length(b) <= 10
16+
alg = GenericLUFactorization()
17+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
18+
alg = RFLUFactorization()
19+
else
20+
alg = LUFactorization()
21+
end
1822
else
1923
alg = LUFactorization()
2024
end
@@ -58,12 +62,19 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5862
# it makes sense according to the benchmarks, which is dependent on
5963
# whether MKL or OpenBLAS is being used
6064
if A isa Matrix
61-
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
62-
ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
63-
(isopenblas() && size(A,1) <= 500)
64-
)
65-
alg = RFLUFactorization()
66-
SciMLBase.solve(cache, alg, args...; kwargs...)
65+
b = cache.b
66+
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
67+
ArrayInterface.can_setindex(b)
68+
if length(b) <= 10
69+
alg = GenericLUFactorization()
70+
SciMLBase.solve(cache, alg, args...; kwargs...)
71+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
72+
alg = RFLUFactorization()
73+
SciMLBase.solve(cache, alg, args...; kwargs...)
74+
else
75+
alg = LUFactorization()
76+
SciMLBase.solve(cache, alg, args...; kwargs...)
77+
end
6778
else
6879
alg = LUFactorization()
6980
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -110,12 +121,18 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
110121
# it makes sense according to the benchmarks, which is dependent on
111122
# whether MKL or OpenBLAS is being used
112123
if A isa Matrix
113-
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
114-
ArrayInterface.can_setindex(b) && (size(A,1) <= 100 ||
115-
(isopenblas() && size(A,1) <= 500)
116-
)
117-
alg = RFLUFactorization()
118-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
124+
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
125+
ArrayInterface.can_setindex(b)
126+
if length(b) <= 10
127+
alg = GenericLUFactorization()
128+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
129+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
130+
alg = RFLUFactorization()
131+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
132+
else
133+
alg = LUFactorization()
134+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
135+
end
119136
else
120137
alg = LUFactorization()
121138
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)

src/factorization.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ struct LUFactorization{P} <: AbstractFactorization
2525
pivot::P
2626
end
2727

28+
struct GenericLUFactorization{P} <: AbstractFactorization
29+
pivot::P
30+
end
31+
2832
function LUFactorization()
2933
pivot = @static if VERSION < v"1.7beta"
3034
Val(true)
@@ -34,6 +38,15 @@ function LUFactorization()
3438
LUFactorization(pivot)
3539
end
3640

41+
function GenericLUFactorization()
42+
pivot = @static if VERSION < v"1.7beta"
43+
Val(true)
44+
else
45+
RowMaximum()
46+
end
47+
GenericLUFactorization(pivot)
48+
end
49+
3750
function do_factorization(alg::LUFactorization, A, b, u)
3851
A = convert(AbstractMatrix,A)
3952
if A isa SparseMatrixCSC
@@ -44,7 +57,13 @@ function do_factorization(alg::LUFactorization, A, b, u)
4457
return fact
4558
end
4659

47-
init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
60+
function do_factorization(alg::GenericLUFactorization, A, b, u)
61+
A = convert(AbstractMatrix,A)
62+
fact = LinearAlgebra.generic_lufact!(A, alg.pivot)
63+
return fact
64+
end
65+
66+
init_cacheval(alg::Union{LUFactorization,GenericLUFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
4867

4968
# This could be a GenericFactorization perhaps?
5069
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization

src/simplelu.jl

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
## From https://github.com/JuliaGNI/SimpleSolvers.jl/blob/master/src/linear/lu_solver.jl
2+
3+
mutable struct LUSolver{T}
4+
n::Int
5+
A::Matrix{T}
6+
b::Vector{T}
7+
x::Vector{T}
8+
pivots::Vector{Int}
9+
perms::Vector{Int}
10+
info::Int
11+
12+
LUSolver{T}(n) where {T} = new(n, zeros(T, n, n), zeros(T, n), zeros(T, n), zeros(Int, n), zeros(Int, n), 0)
13+
end
14+
15+
function LUSolver(A::Matrix{T}) where {T}
16+
n = LinearAlgebra.checksquare(A)
17+
lu = LUSolver{eltype(A)}(n)
18+
lu.A .= A
19+
lu
20+
end
21+
22+
function LUSolver(A::Matrix{T}, b::Vector{T}) where {T}
23+
n = LinearAlgebra.checksquare(A)
24+
@assert n == length(b)
25+
lu = LUSolver{eltype(A)}(n)
26+
lu.A .= A
27+
lu.b .= b
28+
lu
29+
end
30+
31+
function simplelu_factorize!(lu::LUSolver{T}, pivot=true) where {T}
32+
A = lu.A
33+
34+
begin
35+
@inbounds for i in eachindex(lu.perms)
36+
lu.perms[i] = i
37+
end
38+
39+
@inbounds for k = 1:lu.n
40+
# find index max
41+
kp = k
42+
if pivot
43+
amax = real(zero(T))
44+
for i = k:lu.n
45+
absi = abs(A[i,k])
46+
if absi > amax
47+
kp = i
48+
amax = absi
49+
end
50+
end
51+
end
52+
lu.pivots[k] = kp
53+
lu.perms[k], lu.perms[kp] = lu.perms[kp], lu.perms[k]
54+
55+
if A[kp,k] != 0
56+
if k != kp
57+
# Interchange
58+
for i = 1:lu.n
59+
tmp = A[k,i]
60+
A[k,i] = A[kp,i]
61+
A[kp,i] = tmp
62+
end
63+
end
64+
# Scale first column
65+
Akkinv = inv(A[k,k])
66+
for i = k+1:lu.n
67+
A[i,k] *= Akkinv
68+
end
69+
elseif lu.info == 0
70+
lu.info = k
71+
end
72+
# Update the rest
73+
for j = k+1:lu.n
74+
for i = k+1:lu.n
75+
A[i,j] -= A[i,k]*A[k,j]
76+
end
77+
end
78+
end
79+
80+
lu.info
81+
end
82+
end
83+
84+
function simplelu_solve!(lu::LUSolver{T}) where {T}
85+
@inbounds for i = 1:lu.n
86+
lu.x[i] = lu.b[lu.perms[i]]
87+
end
88+
89+
@inbounds for i = 2:lu.n
90+
s = zero(T)
91+
for j = 1:i-1
92+
s += lu.A[i,j] * lu.x[j]
93+
end
94+
lu.x[i] -= s
95+
end
96+
97+
lu.x[lu.n] /= lu.A[lu.n,lu.n]
98+
@inbounds for i = lu.n-1:-1:1
99+
s = zero(T)
100+
for j = i+1:lu.n
101+
s += lu.A[i,j] * lu.x[j]
102+
end
103+
lu.x[i] -= s
104+
lu.x[i] /= lu.A[i,i]
105+
end
106+
107+
copyto!(lu.b,lu.x)
108+
109+
lu.x
110+
end
111+
112+
### Wrapper
113+
114+
struct SimpleLUFactorization <: AbstractFactorization
115+
pivot::Bool
116+
SimpleLUFactorization(pivot=true) = new(pivot)
117+
end
118+
119+
function SciMLBase.solve(cache::LinearCache, alg::SimpleLUFactorization; kwargs...)
120+
if cache.isfresh
121+
cache.cacheval.A = cache.A
122+
simplelu_factorize!(cache.cacheval, alg.pivot)
123+
end
124+
cache.cacheval.b = cache.b
125+
cache.cacheval.x = cache.u
126+
y = simplelu_solve!(cache.cacheval)
127+
SciMLBase.build_linear_solution(alg,y,nothing,cache)
128+
end
129+
130+
init_cacheval(alg::SimpleLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = LUSolver(convert(AbstractMatrix,A))

0 commit comments

Comments
 (0)