Skip to content

Commit f336a16

Browse files
improve defaults and add simplelu
1 parent 3d12bd8 commit f336a16

File tree

4 files changed

+171
-18
lines changed

4 files changed

+171
-18
lines changed

LICENSE

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2021 Jonathan <[email protected]> and contributors
3+
Copyright (c) 2021 SciML and contributors
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal
@@ -19,3 +19,6 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1919
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2020
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
SOFTWARE.
22+
23+
SimpleLU.jl is derived from https://github.com/JuliaGNI/SimpleSolvers.jl under
24+
an MIT license.

src/LinearSolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
3636

3737
include("common.jl")
3838
include("factorization.jl")
39+
include("simplelu.jl")
3940
include("iterative_wrappers.jl")
4041
include("preconditioners.jl")
4142
include("default.jl")
@@ -45,7 +46,8 @@ const IS_OPENBLAS = Ref(true)
4546
isopenblas() = IS_OPENBLAS[]
4647

4748
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
48-
GenericLUFactorization, RFLUFactorization, UMFPACKFactorization, KLUFactorization
49+
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
50+
UMFPACKFactorization, KLUFactorization
4951
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
5052
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
5153
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES

src/default.jl

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ function defaultalg(A,b)
1111
# whether MKL or OpenBLAS is being used
1212
if (A === nothing && !isgpu(b)) || A isa Matrix
1313
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
14-
ArrayInterface.can_setindex(b) && (length(b) <= 100 ||
15-
(isopenblas() && length(b) <= 500)
16-
)
17-
alg = RFLUFactorization()
14+
ArrayInterface.can_setindex(b)
15+
if length(b) <= 10
16+
alg = GenericLUFactorization()
17+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
18+
alg = RFLUFactorization()
19+
else
20+
alg = LUFactorization()
21+
end
1822
else
1923
alg = LUFactorization()
2024
end
@@ -58,12 +62,18 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5862
# it makes sense according to the benchmarks, which is dependent on
5963
# whether MKL or OpenBLAS is being used
6064
if A isa Matrix
61-
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
62-
ArrayInterface.can_setindex(cache.b) && (size(A,1) <= 100 ||
63-
(isopenblas() && size(A,1) <= 500)
64-
)
65-
alg = RFLUFactorization()
66-
SciMLBase.solve(cache, alg, args...; kwargs...)
65+
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
66+
ArrayInterface.can_setindex(b)
67+
if length(b) <= 10
68+
alg = GenericLUFactorization()
69+
SciMLBase.solve(cache, alg, args...; kwargs...)
70+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
71+
alg = RFLUFactorization()
72+
SciMLBase.solve(cache, alg, args...; kwargs...)
73+
else
74+
alg = LUFactorization()
75+
SciMLBase.solve(cache, alg, args...; kwargs...)
76+
end
6777
else
6878
alg = LUFactorization()
6979
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -110,12 +120,18 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
110120
# it makes sense according to the benchmarks, which is dependent on
111121
# whether MKL or OpenBLAS is being used
112122
if A isa Matrix
113-
if eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
114-
ArrayInterface.can_setindex(b) && (size(A,1) <= 100 ||
115-
(isopenblas() && size(A,1) <= 500)
116-
)
117-
alg = RFLUFactorization()
118-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
123+
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
124+
ArrayInterface.can_setindex(b)
125+
if length(b) <= 10
126+
alg = GenericLUFactorization()
127+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
128+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500))
129+
alg = RFLUFactorization()
130+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
131+
else
132+
alg = LUFactorization()
133+
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
134+
end
119135
else
120136
alg = LUFactorization()
121137
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)

src/simplelu.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
## From https://github.com/JuliaGNI/SimpleSolvers.jl/blob/master/src/linear/lu_solver.jl
2+
3+
"""
    LUSolver{T}(n)
    LUSolver(A::AbstractMatrix)
    LUSolver(A::AbstractMatrix, b::AbstractVector)

Mutable workspace for a dense LU factorization and solve of an `n × n` system.

# Fields
- `n`: dimension of the (square) system.
- `A`: `n × n` matrix storage; `simplelu_factorize!` works in place on this matrix.
- `b`: right-hand side; `simplelu_solve!` overwrites it with the solution.
- `x`: solution vector filled by `simplelu_solve!`.
- `pivots`: row selected as pivot at each elimination step.
- `perms`: accumulated row permutation applied to `b` before substitution.
- `info`: `0`, or the index of the first column in which a zero pivot was found.

`LUSolver{T}(n)` allocates zero-initialised storage. The outer constructors copy
`A` (and `b`) into freshly allocated dense storage, so the caller's arrays are not
aliased by the workspace.

# Throws
- `DimensionMismatch` if `A` is not square, or if `length(b)` differs from the
  dimension of `A`.
"""
mutable struct LUSolver{T}
    n::Int
    A::Matrix{T}
    b::Vector{T}
    x::Vector{T}
    pivots::Vector{Int}
    perms::Vector{Int}
    info::Int

    function LUSolver{T}(n) where {T}
        return new{T}(n, zeros(T, n, n), zeros(T, n), zeros(T, n),
                      zeros(Int, n), zeros(Int, n), 0)
    end
end

# Accept any AbstractMatrix: the data is copied into a dense `Matrix{T}` anyway.
function LUSolver(A::AbstractMatrix{T}) where {T}
    n = LinearAlgebra.checksquare(A)
    lu = LUSolver{T}(n)
    lu.A .= A
    return lu
end

function LUSolver(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T}
    n = LinearAlgebra.checksquare(A)
    # Explicit check instead of `@assert`: assertions are meant for internal
    # invariants and may be disabled, whereas this validates caller input.
    if n != length(b)
        throw(DimensionMismatch(
            "matrix has dimension $n but right-hand side has length $(length(b))"))
    end
    lu = LUSolver{T}(n)
    lu.A .= A
    lu.b .= b
    return lu
end
30+
31+
"""
    simplelu_factorize!(lu::LUSolver, pivot=true)

Compute the LU factorization of `lu.A` in place.

After the call, the strictly lower triangle of `lu.A` holds the multipliers and
the upper triangle (including the diagonal) holds the eliminated rows.
`lu.pivots[k]` records the row chosen at step `k`, and `lu.perms` holds the
accumulated row permutation.

# Arguments
- `lu`: workspace whose `A` field is factorized and overwritten.
- `pivot`: when `true`, the entry of largest absolute value in the current column
  (on or below the diagonal) is used as pivot; when `false`, the diagonal entry
  is always used.

# Returns
`lu.info`: `0` if every pivot was nonzero, otherwise the index of the first
column in which a zero pivot was encountered.
"""
function simplelu_factorize!(lu::LUSolver{T}, pivot=true) where {T}
    A = lu.A
    n = lu.n

    # Reset the status flag so that re-factorizing a reused workspace does not
    # report (or get masked by) a singularity recorded by an earlier call.
    lu.info = 0

    @inbounds for i in eachindex(lu.perms)
        lu.perms[i] = i
    end

    @inbounds for k in 1:n
        # Find the pivot row: largest |A[i,k]| for i >= k (or just k if !pivot).
        kp = k
        if pivot
            amax = real(zero(T))
            for i in k:n
                absi = abs(A[i, k])
                if absi > amax
                    kp = i
                    amax = absi
                end
            end
        end
        lu.pivots[k] = kp
        lu.perms[k], lu.perms[kp] = lu.perms[kp], lu.perms[k]

        if A[kp, k] != 0
            if k != kp
                # Interchange rows k and kp across all columns.
                for j in 1:n
                    tmp = A[k, j]
                    A[k, j] = A[kp, j]
                    A[kp, j] = tmp
                end
            end
            # Scale the sub-diagonal part of column k by the inverse pivot.
            Akkinv = inv(A[k, k])
            for i in (k + 1):n
                A[i, k] *= Akkinv
            end
        elseif lu.info == 0
            # Remember only the first zero pivot.
            lu.info = k
        end

        # Update the trailing submatrix (column-major traversal).
        for j in (k + 1):n
            for i in (k + 1):n
                A[i, j] -= A[i, k] * A[k, j]
            end
        end
    end

    return lu.info
end
83+
84+
"""
    simplelu_solve!(lu::LUSolver)

Solve the linear system using the factors stored in `lu.A`, the permutation in
`lu.perms`, and the right-hand side in `lu.b`.

The permuted right-hand side is copied into `lu.x`, followed by a forward
substitution with the unit lower triangle of `lu.A` and a backward substitution
with its upper triangle. The result is also copied back into `lu.b`.

An empty system (`lu.n == 0`) is handled without error.

# Returns
`lu.x`, the solution vector (the same array object stored in the workspace).
"""
function simplelu_solve!(lu::LUSolver{T}) where {T}
    local s::T
    n = lu.n

    # Apply the row permutation to the right-hand side.
    @inbounds for i in 1:n
        lu.x[i] = lu.b[lu.perms[i]]
    end

    # Forward substitution; the diagonal of L is implicitly one.
    @inbounds for i in 2:n
        s = zero(T)
        for j in 1:(i - 1)
            s += lu.A[i, j] * lu.x[j]
        end
        lu.x[i] -= s
    end

    # Backward substitution. The last row is handled separately and is skipped
    # for an empty system, where indexing at position 0 would be out of bounds.
    if n > 0
        lu.x[n] /= lu.A[n, n]
    end
    @inbounds for i in (n - 1):-1:1
        s = zero(T)
        for j in (i + 1):n
            s += lu.A[i, j] * lu.x[j]
        end
        lu.x[i] -= s
        lu.x[i] /= lu.A[i, i]
    end

    lu.b .= lu.x

    return lu.x
end
113+
114+
### Wrapper
115+
116+
"""
    SimpleLUFactorization(pivot=true)

Algorithm selector for the pure-Julia LU solver defined in this file.

# Fields
- `pivot`: forwarded as the `pivot` argument of `simplelu_factorize!`.
"""
struct SimpleLUFactorization <: AbstractFactorization
    pivot::Bool

    function SimpleLUFactorization(pivot = true)
        return new(pivot)
    end
end
120+
121+
# Solve the cached linear system with the simple LU solver.
#
# When `cache.isfresh` is true, `cache.A` is handed to the workspace stored in
# `cache.cacheval` and factorized; otherwise the existing factors are reused.
# `cache.b` and `cache.u` are then installed as the workspace's right-hand side
# and solution buffers before the triangular solves run.
#
# NOTE(review): `cache.A` and `cache.b` are stored by reference, and both
# `simplelu_factorize!` and `simplelu_solve!` write into those fields, so the
# cache's `A` and `b` are overwritten by this call - confirm this is intended.
# NOTE(review): `cache.isfresh` is not cleared here; verify that a repeated
# solve does not re-factorize the already factorized matrix.
function SciMLBase.solve(cache::LinearCache, alg::SimpleLUFactorization; kwargs...)
    solver = cache.cacheval
    if cache.isfresh
        solver.A = cache.A
        simplelu_factorize!(solver, alg.pivot)
    end
    solver.b = cache.b
    solver.x = cache.u
    solution = simplelu_solve!(solver)
    return SciMLBase.build_linear_solution(alg, solution, nothing, cache)
end
131+
132+
# Build the cache value for `SimpleLUFactorization`: an `LUSolver` workspace
# constructed from `A`. The remaining arguments are accepted but not used.
function init_cacheval(alg::SimpleLUFactorization, A, b, u, Pl, Pr, maxiters, abstol,
                       reltol, verbose)
    matrix = convert(AbstractMatrix, A)
    return LUSolver(matrix)
end

0 commit comments

Comments
 (0)