Skip to content

Commit 9cfadfc

Browse files
Symbolic Factorization Reuse in the standard LUFactorization
This should be sufficient for CUDSS.jl to be optimally used as well
1 parent d050e01 commit 9cfadfc

File tree

3 files changed

+43
-12
lines changed

3 files changed

+43
-12
lines changed

src/LinearSolve.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ _isidentity_struct(::SciMLOperators.IdentityOperator) = true
8282
# Dispatch Friendly way to check if an extension is loaded
8383
__is_extension_loaded(::Val) = false
8484

85+
# Check if a sparsity pattern has changed
86+
pattern_changed(fact, A) = false
87+
8588
function _fast_sym_givens! end
8689

8790
# Code

src/factorization.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,15 @@ Julia's built in `lu`. Equivalent to calling `lu!(A)`
4949
- pivot: The choice of pivoting. Defaults to `LinearAlgebra.RowMaximum()`. The other choice is
5050
`LinearAlgebra.NoPivot()`.
5151
"""
52-
struct LUFactorization{P} <: AbstractFactorization
53-
pivot::P
52+
Base.@kwdef struct LUFactorization{P} <: AbstractFactorization
53+
pivot::P = LinearAlgebra.RowMaximum()
54+
reuse_symbolic::Bool = true
55+
check_pattern::Bool = true # Check factorization re-use
5456
end
5557

58+
# Legacy dispatch
59+
LUFactorization(pivot) = LUFactorization(;pivot=RowMaximum())
60+
5661
"""
5762
`GenericLUFactorization(pivot=LinearAlgebra.RowMaximum())`
5863
@@ -69,10 +74,33 @@ struct GenericLUFactorization{P} <: AbstractFactorization
6974
pivot::P
7075
end
7176

72-
LUFactorization() = LUFactorization(RowMaximum())
73-
7477
GenericLUFactorization() = GenericLUFactorization(RowMaximum())
7578

79+
function SciMLBase.solve!(cache::LinearCache, alg::LUFactorization; kwargs...)
80+
A = cache.A
81+
A = convert(AbstractMatrix, A)
82+
if cache.isfresh
83+
cacheval = @get_cacheval(cache, :LUFactorization)
84+
if A isa AbstractSparseMatrix && alg.reuse_symbolic
85+
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
86+
# If SparseMatrixCSC, check if the pattern has changed
87+
if alg.check_pattern && pattern_changed(cacheval, A)
88+
fact = lu(A, check = false)
89+
else
90+
fact = lu!(cacheval, A, check = false)
91+
end
92+
else
93+
fact = lu(A, check = false)
94+
end
95+
cache.cacheval = fact
96+
cache.isfresh = false
97+
end
98+
99+
F = @get_cacheval(cache, :LUFactorization)
100+
y = ldiv!(cache.u, F, cache.b)
101+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
102+
end
103+
76104
function do_factorization(alg::LUFactorization, A, b, u)
77105
A = convert(AbstractMatrix, A)
78106
if A isa AbstractSparseMatrixCSC
@@ -775,10 +803,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::UMFPACKFactorization; kwargs.
775803
cacheval = @get_cacheval(cache, :UMFPACKFactorization)
776804
if alg.reuse_symbolic
777805
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
778-
if alg.check_pattern && !(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
779-
cacheval.colptr &&
780-
SparseArrays.decrement(SparseArrays.getrowval(A)) ==
781-
cacheval.rowval)
806+
if alg.check_pattern && pattern_changed(cacheval, A)
782807
fact = lu(
783808
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
784809
nonzeros(A)),
@@ -856,10 +881,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KLUFactorization; kwargs...)
856881
if cache.isfresh
857882
cacheval = @get_cacheval(cache, :KLUFactorization)
858883
if alg.reuse_symbolic
859-
if alg.check_pattern && !(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
860-
cacheval.colptr &&
861-
SparseArrays.decrement(SparseArrays.getrowval(A)) ==
862-
cacheval.rowval)
884+
if alg.check_pattern && pattern_changed(cacheval, A)
863885
fact = KLU.klu(
864886
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
865887
nonzeros(A)),

src/factorization_sparse.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ function _ldiv!(::SVector,
2727
b::SVector)
2828
(A \ b)
2929
end
30+
31+
function pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
32+
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
33+
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
34+
fact.rowval)
35+
end

0 commit comments

Comments
 (0)