Skip to content

Commit e101fd3

Browse files
Merge pull request #618 from SciML/generic_lufact_alloc
Remove ipiv allocation from GenericLUFactorization
2 parents 35662db + 62cdcbf commit e101fd3

File tree

4 files changed

+229
-41
lines changed

4 files changed

+229
-41
lines changed

ext/LinearSolveSparseArraysExt.jl

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,15 @@ const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0
6464
Int[], Float64[]))
6565

6666
function LinearSolve.init_cacheval(
67-
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{<:Number, <:Integer}, b, u,
67+
alg::LUFactorization, A::AbstractSparseArray{<:Number, <:Integer}, b, u,
68+
Pl, Pr,
69+
maxiters::Int, abstol, reltol,
70+
verbose::Bool, assumptions::OperatorAssumptions)
71+
nothing
72+
end
73+
74+
function LinearSolve.init_cacheval(
75+
alg::GenericLUFactorization, A::AbstractSparseArray{<:Number, <:Integer}, b, u,
6876
Pl, Pr,
6977
maxiters::Int, abstol, reltol,
7078
verbose::Bool, assumptions::OperatorAssumptions)
@@ -80,23 +88,23 @@ function LinearSolve.init_cacheval(
8088
end
8189

8290
function LinearSolve.init_cacheval(
83-
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{Float64, Int64}, b, u,
91+
alg::LUFactorization, A::AbstractSparseArray{Float64, Int64}, b, u,
8492
Pl, Pr,
8593
maxiters::Int, abstol, reltol,
8694
verbose::Bool, assumptions::OperatorAssumptions)
8795
PREALLOCATED_UMFPACK
8896
end
8997

9098
function LinearSolve.init_cacheval(
91-
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{T, Int64}, b, u,
99+
alg::LUFactorization, A::AbstractSparseArray{T, Int64}, b, u,
92100
Pl, Pr,
93101
maxiters::Int, abstol, reltol,
94102
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
95103
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
96104
end
97105

98106
function LinearSolve.init_cacheval(
99-
alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{T, Int32}, b, u,
107+
alg::LUFactorization, A::AbstractSparseArray{T, Int32}, b, u,
100108
Pl, Pr,
101109
maxiters::Int, abstol, reltol,
102110
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}

src/LinearSolve.jl

Lines changed: 1 addition & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ end
140140

141141
const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}
142142

143+
include("generic_lufact.jl")
143144
include("common.jl")
144145
include("extension_algs.jl")
145146
include("factorization.jl")
@@ -171,28 +172,6 @@ end
171172
@inline _notsuccessful(F) = hasmethod(LinearAlgebra.issuccess, (typeof(F),)) ?
172173
!LinearAlgebra.issuccess(F) : false
173174

174-
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
175-
kwargs...)
176-
quote
177-
if cache.isfresh
178-
fact = do_factorization(alg, cache.A, cache.b, cache.u)
179-
cache.cacheval = fact
180-
181-
# If factorization was not successful, return failure. Don't reset `isfresh`
182-
if _notsuccessful(fact)
183-
return SciMLBase.build_linear_solution(
184-
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
185-
end
186-
187-
cache.isfresh = false
188-
end
189-
190-
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
191-
cache.b)
192-
return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
193-
end
194-
end
195-
196175
# Solver Specific Traits
197176
## Needs Square Matrix
198177
"""

src/factorization.jl

Lines changed: 82 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,25 @@
1+
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
2+
kwargs...)
3+
quote
4+
if cache.isfresh
5+
fact = do_factorization(alg, cache.A, cache.b, cache.u)
6+
cache.cacheval = fact
7+
8+
# If factorization was not successful, return failure. Don't reset `isfresh`
9+
if _notsuccessful(fact)
10+
return SciMLBase.build_linear_solution(
11+
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
12+
end
13+
14+
cache.isfresh = false
15+
end
16+
17+
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))),
18+
cache.b)
19+
return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success)
20+
end
21+
end
22+
123
macro get_cacheval(cache, algsym)
224
quote
325
if $(esc(cache)).alg isa DefaultLinearSolver
@@ -8,6 +30,8 @@ macro get_cacheval(cache, algsym)
830
end
931
end
1032

33+
const PREALLOCATED_IPIV = Vector{LinearAlgebra.BlasInt}(undef, 0)
34+
1135
_ldiv!(x, A, b) = ldiv!(x, A, b)
1236

1337
_ldiv!(x, A, b::SVector) = (x .= A \ b)
@@ -41,8 +65,7 @@ function LinearSolve.init_cacheval(
4165
alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
4266
maxiters::Int,
4367
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
44-
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
45-
PREALLOCATED_LU, ipiv
68+
PREALLOCATED_LU, PREALLOCATED_IPIV
4669
end
4770

4871
function LinearSolve.init_cacheval(alg::RFLUFactorization,
@@ -144,41 +167,85 @@ function do_factorization(alg::LUFactorization, A, b, u)
144167
return fact
145168
end
146169

147-
function do_factorization(alg::GenericLUFactorization, A, b, u)
170+
function init_cacheval(
171+
alg::GenericLUFactorization, A, b, u, Pl, Pr,
172+
maxiters::Int, abstol, reltol, verbose::Bool,
173+
assumptions::OperatorAssumptions)
174+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
175+
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv
176+
end
177+
178+
function init_cacheval(
179+
alg::GenericLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
180+
maxiters::Int, abstol, reltol, verbose::Bool,
181+
assumptions::OperatorAssumptions)
182+
PREALLOCATED_LU, PREALLOCATED_IPIV
183+
end
184+
185+
# Solve the cached linear system for `GenericLUFactorization`.
#
# `cache.cacheval` is read as a `(fact, ipiv)` tuple. When `cache.isfresh` is set,
# `cache.A` is refactorized with `generic_lufact!`, passing `ipiv` so the pivot
# vector is reused instead of being reallocated on every factorization. The new
# `(fact, ipiv)` pair is stored back into `cache.cacheval`.
#
# Returns the solution built by `SciMLBase.build_linear_solution`:
#   * `retcode = ReturnCode.Failure` when the fresh factorization is not successful
#     (`cache.isfresh` is deliberately left set in that case);
#   * `retcode = ReturnCode.Success` otherwise, matching the generic
#     `solve!(::LinearCache, ::AbstractFactorization)` method.
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::GenericLUFactorization;
        kwargs...)
    A = convert(AbstractMatrix, cache.A)
    fact, ipiv = LinearSolve.@get_cacheval(cache, :GenericLUFactorization)

    if cache.isfresh
        npivots = min(size(A)...)
        # Only allocate a new pivot vector when the cached one does not fit `A`.
        if length(ipiv) != npivots
            ipiv = Vector{LinearAlgebra.BlasInt}(undef, npivots)
        end
        fact = generic_lufact!(A, alg.pivot, ipiv; check = false)
        cache.cacheval = (fact, ipiv)

        # If factorization was not successful, return failure. Don't reset `isfresh`
        if !LinearAlgebra.issuccess(fact)
            return SciMLBase.build_linear_solution(
                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
        end

        cache.isfresh = false
    end

    y = ldiv!(
        cache.u, LinearSolve.@get_cacheval(cache, :GenericLUFactorization)[1], cache.b)
    return SciMLBase.build_linear_solution(
        alg, y, nothing, cache; retcode = ReturnCode.Success)
end
152208

153209
function init_cacheval(
154-
alg::Union{LUFactorization, GenericLUFactorization}, A, b, u, Pl, Pr,
210+
alg::LUFactorization, A, b, u, Pl, Pr,
155211
maxiters::Int, abstol, reltol, verbose::Bool,
156212
assumptions::OperatorAssumptions)
157213
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
158214
end
159215

160-
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
216+
function init_cacheval(alg::LUFactorization,
161217
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
162218
verbose::Bool, assumptions::OperatorAssumptions)
163219
error_no_cudss_lu(A)
164-
if alg isa LUFactorization
165-
return lu(A; check = false)
166-
else
167-
A isa GPUArraysCore.AnyGPUArray && return nothing
168-
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false)
169-
end
220+
return lu(A; check = false)
221+
end
222+
223+
function init_cacheval(alg::GenericLUFactorization,
224+
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
225+
verbose::Bool, assumptions::OperatorAssumptions)
226+
error_no_cudss_lu(A)
227+
A isa GPUArraysCore.AnyGPUArray && return nothing
228+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
229+
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false), ipiv
170230
end
171231

172232
const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))
173233

174-
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
234+
function init_cacheval(alg::LUFactorization,
175235
A::Matrix{Float64}, b, u, Pl, Pr,
176236
maxiters::Int, abstol, reltol, verbose::Bool,
177237
assumptions::OperatorAssumptions)
178238
PREALLOCATED_LU
179239
end
180240

181-
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
241+
function init_cacheval(alg::LUFactorization,
242+
A::AbstractSciMLOperator, b, u, Pl, Pr,
243+
maxiters::Int, abstol, reltol, verbose::Bool,
244+
assumptions::OperatorAssumptions)
245+
nothing
246+
end
247+
248+
function init_cacheval(alg::GenericLUFactorization,
182249
A::AbstractSciMLOperator, b, u, Pl, Pr,
183250
maxiters::Int, abstol, reltol, verbose::Bool,
184251
assumptions::OperatorAssumptions)

src/generic_lufact.jl

Lines changed: 134 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,134 @@
1+
# From LinearAlgebra.lu.jl
2+
# Modified to be non-allocating
3+
# `generic_lufact!(A, pivot, ipiv; check, allowsingular)`
#
# Pure-Julia LU factorization of `A`, performed in place, that writes the row
# pivots into the caller-supplied vector `ipiv` instead of allocating a new one.
#
# Arguments
#   * `A`     - matrix to factorize; it is overwritten with the L and U factors.
#   * `pivot` - `RowMaximum()`, `RowNonZero()` or `NoPivot()`; defaults to
#               `LinearAlgebra.lupivottype(T)`.
#   * `ipiv`  - pivot buffer; must hold at least `min(size(A)...)` entries. A fresh
#               vector is allocated only when the argument is omitted.
# Keywords
#   * `check`         - when `true`, `A` is checked for non-finite entries before
#                       factorizing and the result is checked for success afterwards.
#   * `allowsingular` - forwarded to the success check on Julia 1.11/1.12; accepted
#                       but unused on Julia < 1.11, whose check takes no such flag.
#
# Returns a `LinearAlgebra.LU` wrapping `A`, `ipiv` and the `info` status code.
# Throws `DimensionMismatch` when `ipiv` is shorter than `min(size(A)...)`.
@static if VERSION < v"1.13"
    # Elimination kernel shared by both version-specific entry points.
    # Overwrites `A` with its LU factors, records the pivot row of step `k` in
    # `ipiv[k]`, and returns `info`: 0 if every pivot was nonzero, otherwise the
    # index of the first step with a zero pivot.
    function _generic_lufact_eliminate!(A::AbstractMatrix, pivot, ipiv)
        m, n = size(A)
        minmn = min(m, n)

        # The loop below runs under `@inbounds`, so a pivot buffer that is too
        # short would be written out of bounds. Reject it up front.
        if length(ipiv) < minmn
            throw(DimensionMismatch(
                "ipiv has length $(length(ipiv)), but at least $minmn entries are required"))
        end

        info = 0
        @inbounds for k in 1:minmn
            # Select the pivot row for column k
            kp = k
            if pivot === LinearAlgebra.RowMaximum() && k < m
                amax = abs(A[k, k])
                for i in (k + 1):m
                    absi = abs(A[i, k])
                    if absi > amax
                        kp = i
                        amax = absi
                    end
                end
            elseif pivot === LinearAlgebra.RowNonZero()
                for i in k:m
                    if !iszero(A[i, k])
                        kp = i
                        break
                    end
                end
            end
            ipiv[k] = kp
            if !iszero(A[kp, k])
                if k != kp
                    # Interchange rows k and kp
                    for i in 1:n
                        tmp = A[k, i]
                        A[k, i] = A[kp, i]
                        A[kp, i] = tmp
                    end
                end
                # Scale first column
                Akkinv = inv(A[k, k])
                for i in (k + 1):m
                    A[i, k] *= Akkinv
                end
            elseif info == 0
                # Remember only the first zero pivot
                info = k
            end
            # Update the rest
            for j in (k + 1):n
                for i in (k + 1):m
                    A[i, j] -= A[i, k] * A[k, j]
                end
            end
        end
        return info
    end

    @static if VERSION < v"1.11"
        function generic_lufact!(A::AbstractMatrix{T},
                pivot::Union{LinearAlgebra.RowMaximum, LinearAlgebra.NoPivot,
                    LinearAlgebra.RowNonZero} = LinearAlgebra.lupivottype(T),
                ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...));
                check::Bool = true, allowsingular::Bool = false) where {T}
            check && LinearAlgebra.LAPACK.chkfinite(A)
            info = _generic_lufact_eliminate!(A, pivot, ipiv)
            check && LinearAlgebra.checknonsingular(info, pivot)
            return LinearAlgebra.LU{T, typeof(A), typeof(ipiv)}(
                A, ipiv, convert(LinearAlgebra.BlasInt, info))
        end
    else
        function generic_lufact!(A::AbstractMatrix{T},
                pivot::Union{LinearAlgebra.RowMaximum, LinearAlgebra.NoPivot,
                    LinearAlgebra.RowNonZero} = LinearAlgebra.lupivottype(T),
                ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...));
                check::Bool = true, allowsingular::Bool = false) where {T}
            check && LinearAlgebra.LAPACK.chkfinite(A)
            info = _generic_lufact_eliminate!(A, pivot, ipiv)
            if pivot === LinearAlgebra.NoPivot()
                # Use a negative value to distinguish a failed factorization (zero in pivot
                # position during unpivoted LU) from a valid but rank-deficient factorization
                info = -info
            end
            check && LinearAlgebra._check_lu_success(info, allowsingular)
            return LinearAlgebra.LU{T, typeof(A), typeof(ipiv)}(
                A, ipiv, convert(LinearAlgebra.BlasInt, info))
        end
    end
else
    # Julia >= 1.13: delegate to the implementation in LinearAlgebra.
    generic_lufact!(args...; kwargs...) = LinearAlgebra.generic_lufact!(args...; kwargs...)
end

0 commit comments

Comments
 (0)