Skip to content

Commit 2c004a2

Browse files
Remove ipiv allocation from GenericLUFactorization
1 parent 35662db commit 2c004a2

File tree

3 files changed

+126
-29
lines changed

3 files changed

+126
-29
lines changed

src/LinearSolve.jl

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ end
140140

141141
const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}
142142

143+
include("generic_lufact.jl")
143144
include("common.jl")
144145
include("extension_algs.jl")
145146
include("factorization.jl")
@@ -171,28 +172,6 @@ end
171172
@inline _notsuccessful(F) = hasmethod(LinearAlgebra.issuccess, (typeof(F),)) ?
172173
!LinearAlgebra.issuccess(F) : false
173174

174-
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
175-
kwargs...)
176-
quote
177-
if cache.isfresh
178-
fact = do_factorization(alg, cache.A, cache.b, cache.u)
179-
cache.cacheval = fact
180-
181-
# If factorization was not successful, return failure. Don't reset `isfresh`
182-
if _notsuccessful(fact)
183-
return SciMLBase.build_linear_solution(
184-
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
185-
end
186-
187-
cache.isfresh = false
188-
end
189-
190-
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
191-
cache.b)
192-
return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
193-
end
194-
end
195-
196175
# Solver Specific Traits
197176
## Needs Square Matrix
198177
"""

src/factorization.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,25 @@
1+
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
2+
kwargs...)
3+
quote
4+
if cache.isfresh
5+
fact = do_factorization(alg, cache.A, cache.b, cache.u)
6+
cache.cacheval = fact
7+
8+
# If factorization was not successful, return failure. Don't reset `isfresh`
9+
if _notsuccessful(fact)
10+
return SciMLBase.build_linear_solution(
11+
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
12+
end
13+
14+
cache.isfresh = false
15+
end
16+
17+
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
18+
cache.b)
19+
return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
20+
end
21+
end
22+
123
macro get_cacheval(cache, algsym)
224
quote
325
if $(esc(cache)).alg isa DefaultLinearSolver
@@ -8,6 +30,8 @@ macro get_cacheval(cache, algsym)
830
end
931
end
1032

33+
const PREALLOCATED_IPIV = Vector{LinearAlgebra.BlasInt}(undef, 0)
34+
1135
_ldiv!(x, A, b) = ldiv!(x, A, b)
1236

1337
_ldiv!(x, A, b::SVector) = (x .= A \ b)
@@ -41,8 +65,7 @@ function LinearSolve.init_cacheval(
4165
alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
4266
maxiters::Int,
4367
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
44-
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
45-
PREALLOCATED_LU, ipiv
68+
PREALLOCATED_LU, PREALLOCATED_IPIV
4669
end
4770

4871
function LinearSolve.init_cacheval(alg::RFLUFactorization,
@@ -144,14 +167,47 @@ function do_factorization(alg::LUFactorization, A, b, u)
144167
return fact
145168
end
146169

147-
function do_factorization(alg::GenericLUFactorization, A, b, u)
170+
function init_cacheval(
171+
alg::GenericLUFactorization, A, b, u, Pl, Pr,
172+
maxiters::Int, abstol, reltol, verbose::Bool,
173+
assumptions::OperatorAssumptions)
174+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
175+
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv
176+
end
177+
178+
function init_cacheval(
179+
alg::GenericLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
180+
maxiters::Int, abstol, reltol, verbose::Bool,
181+
assumptions::OperatorAssumptions)
182+
PREALLOCATED_LU, PREALLOCATED_IPIV
183+
end
184+
185+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::GenericLUFactorization;
186+
kwargs...)
187+
A = cache.A
148188
A = convert(AbstractMatrix, A)
149-
fact = LinearAlgebra.generic_lufact!(A, alg.pivot, check = false)
150-
return fact
189+
fact, ipiv = LinearSolve.@get_cacheval(cache, :GenericLUFactorization)
190+
191+
if cache.isfresh
192+
if length(ipiv) != min(size(A)...)
193+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
194+
end
195+
fact = generic_lufact!(A, alg.pivot; check = false, ipiv)
196+
cache.cacheval = (fact, ipiv)
197+
198+
if !LinearAlgebra.issuccess(fact)
199+
return SciMLBase.build_linear_solution(
200+
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
201+
end
202+
203+
cache.isfresh = false
204+
end
205+
y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :GenericLUFactorization)[1], cache.b)
206+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
151207
end
152208

153209
function init_cacheval(
154-
alg::Union{LUFactorization, GenericLUFactorization}, A, b, u, Pl, Pr,
210+
alg::LUFactorization, A, b, u, Pl, Pr,
155211
maxiters::Int, abstol, reltol, verbose::Bool,
156212
assumptions::OperatorAssumptions)
157213
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
@@ -171,7 +227,7 @@ end
171227

172228
const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))
173229

174-
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
230+
function init_cacheval(alg::LUFactorization,
175231
A::Matrix{Float64}, b, u, Pl, Pr,
176232
maxiters::Int, abstol, reltol, verbose::Bool,
177233
assumptions::OperatorAssumptions)

src/generic_lufact.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# From LinearAlgebra.lu.jl
2+
# Modified to be non-allocating
3+
function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = lupivottype(T);
4+
check::Bool = true, ipiv = Vector{BlasInt}(undef, min(size(A)))) where {T}
5+
check && LinearAlgebra.LAPACK.chkfinite(A)
6+
# Extract values
7+
m, n = size(A)
8+
minmn = min(m,n)
9+
10+
# Initialize variables
11+
info = 0
12+
13+
@inbounds begin
14+
for k = 1:minmn
15+
# find index max
16+
kp = k
17+
if pivot === LinearAlgebra.RowMaximum() && k < m
18+
amax = abs(A[k, k])
19+
for i = k+1:m
20+
absi = abs(A[i,k])
21+
if absi > amax
22+
kp = i
23+
amax = absi
24+
end
25+
end
26+
elseif pivot === LinearAlgebra.RowNonZero()
27+
for i = k:m
28+
if !iszero(A[i,k])
29+
kp = i
30+
break
31+
end
32+
end
33+
end
34+
ipiv[k] = kp
35+
if !iszero(A[kp,k])
36+
if k != kp
37+
# Interchange
38+
for i = 1:n
39+
tmp = A[k,i]
40+
A[k,i] = A[kp,i]
41+
A[kp,i] = tmp
42+
end
43+
end
44+
# Scale first column
45+
Akkinv = inv(A[k,k])
46+
for i = k+1:m
47+
A[i,k] *= Akkinv
48+
end
49+
elseif info == 0
50+
info = k
51+
end
52+
# Update the rest
53+
for j = k+1:n
54+
for i = k+1:m
55+
A[i,j] -= A[i,k]*A[k,j]
56+
end
57+
end
58+
end
59+
end
60+
check && LinearAlgebra.checknonsingular(info, pivot)
61+
return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info))
62+
end

0 commit comments

Comments
 (0)