Skip to content

Commit 736a9cd

Browse files
authored
Pure-Julia Sparse Cholesky (#721)
* Add CliqueTrees Cholesky factorization. * Update documentation. * Address comments.
1 parent d3aa580 commit 736a9cd

File tree

6 files changed

+163
-3
lines changed

6 files changed

+163
-3
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3131
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3232
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3333
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
34+
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
3435
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3536
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
3637
CUSOLVERRF = "a8cc9031-bad2-4722-94f5-40deabb4245c"
@@ -55,6 +56,7 @@ LinearSolveAMDGPUExt = "AMDGPU"
5556
LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"]
5657
LinearSolveBandedMatricesExt = "BandedMatrices"
5758
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
59+
LinearSolveCliqueTreesExt = ["CliqueTrees", "SparseArrays"]
5860
LinearSolveCUDAExt = "CUDA"
5961
LinearSolveCUDSSExt = "CUDSS"
6062
LinearSolveCUSOLVERRFExt = ["CUSOLVERRF", "SparseArrays"]
@@ -83,6 +85,7 @@ CUDA = "5"
8385
CUDSS = "0.4"
8486
CUSOLVERRF = "0.2.6"
8587
ChainRulesCore = "1.22"
88+
CliqueTrees = "1.11.0"
8689
ConcreteStructs = "0.2.3"
8790
DocStringExtensions = "0.9.3"
8891
EnumX = "1.0.4"
@@ -136,6 +139,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
136139
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
137140
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
138141
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
142+
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
139143
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
140144
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
141145
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -163,4 +167,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
163167
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
164168

165169
[targets]
166-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
170+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports"]

docs/src/solvers/solvers.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,16 @@ UMFPACKFactorization
177177
SparspakFactorization
178178
```
179179

180+
### CliqueTrees.jl
181+
182+
!!! note
183+
184+
Using this solver requires adding the package CliqueTrees.jl, i.e. `using CliqueTrees`
185+
186+
```@docs
187+
CliqueTreesFactorization
188+
```
189+
180190
### Krylov.jl
181191

182192
```@docs

ext/LinearSolveCliqueTreesExt.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
module LinearSolveCliqueTreesExt
2+
3+
using CliqueTrees: EliminationAlgorithm, SupernodeType, DEFAULT_ELIMINATION_ALGORITHM,
4+
DEFAULT_SUPERNODE_TYPE, symbolic, cholinit, lininit, cholesky!, linsolve!
5+
using LinearSolve
6+
using SparseArrays
7+
8+
function LinearSolve.CliqueTreesFactorization(;
9+
alg::A=DEFAULT_ELIMINATION_ALGORITHM,
10+
snd::S=DEFAULT_SUPERNODE_TYPE,
11+
reuse_symbolic::Bool=true,
12+
) where {A <: EliminationAlgorithm, S <: SupernodeType}
13+
return CliqueTreesFactorization{A, S}(alg, snd, reuse_symbolic)
14+
end
15+
16+
function LinearSolve.init_cacheval(
17+
alg::CliqueTreesFactorization, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol,
18+
reltol, verbose::Bool, assumptions::OperatorAssumptions)
19+
symbfact = symbolic(A; alg=alg.alg, snd=alg.snd)
20+
cholfact, cholwork = cholinit(A, symbfact)
21+
linwork = lininit(1, cholfact)
22+
return (cholfact, cholwork, linwork)
23+
end
24+
25+
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CliqueTreesFactorization; kwargs...)
26+
A = cache.A
27+
u = cache.u
28+
b = cache.b
29+
30+
if cache.isfresh
31+
if isnothing(cache.cacheval) || !alg.reuse_symbolic
32+
symbfact = symbolic(A; alg=alg.alg, snd=alg.snd)
33+
cholfact, cholwork = cholinit(A, symbfact)
34+
linwork = lininit(1, cholfact)
35+
cache.cacheval = (cholfact, cholwork, linwork)
36+
end
37+
38+
cholfact, cholwork, linwork = cache.cacheval
39+
cholesky!(cholfact, cholwork, A)
40+
cache.isfresh = false
41+
end
42+
43+
cholfact, cholwork, linwork = cache.cacheval
44+
linsolve!(copyto!(u, b), linwork, cholfact, Val(false))
45+
return SciMLBase.build_linear_solution(alg, u, nothing, cache)
46+
end
47+
48+
LinearSolve.PrecompileTools.@compile_workload begin
49+
A = sparse([
50+
3 1 0 0 0 0 0 0
51+
1 3 1 0 0 2 0 0
52+
0 1 3 1 0 1 2 1
53+
0 0 1 3 0 0 0 0
54+
0 0 0 0 3 1 1 0
55+
0 2 1 0 1 3 0 0
56+
0 0 2 0 1 0 3 1
57+
0 0 1 0 0 0 1 3
58+
])
59+
60+
b = rand(8)
61+
prob = LinearProblem(A, b)
62+
sol = solve(prob, CliqueTreesFactorization())
63+
end
64+
65+
end

src/LinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
241241
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
242242
SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
243243
BunchKaufmanFactorization, CHOLMODFactorization, LDLtFactorization,
244-
CUSOLVERRFFactorization
244+
CUSOLVERRFFactorization, CliqueTreesFactorization
245245

246246
export LinearSolveFunction, DirectLdiv!
247247

src/factorization.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,35 @@ function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
11581158
nothing
11591159
end
11601160

1161+
## CliqueTreesFactorization is here since it's MIT licensed, not GPL
1162+
1163+
"""
1164+
CliqueTreesFactorization(
1165+
alg = CliqueTrees.DEFAULT_ELIMINATION_ALGORITHM,
1166+
snd = CliqueTrees.DEFAULT_SUPERNODE_TYPE,
1167+
reuse_symbolic = true,
1168+
)
1169+
1170+
The sparse Cholesky factorization algorithm implemented in CliqueTrees.jl.
1171+
The implementation is pure-Julia and accepts arbitrary numeric types. It is
1172+
somewhat slower than CHOLMOD.
1173+
"""
1174+
struct CliqueTreesFactorization{A, S} <: AbstractSparseFactorization
1175+
alg::A
1176+
snd::S
1177+
reuse_symbolic::Bool
1178+
end
1179+
1180+
function init_cacheval(::CliqueTreesFactorization, ::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
1181+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1182+
nothing
1183+
end
1184+
1185+
function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
1186+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1187+
nothing
1188+
end
1189+
11611190
for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
11621191
InteractiveUtils.subtypes(AbstractSparseFactorization))
11631192
@eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,

test/basictests.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
22
using SciMLOperators, RecursiveFactorization, Sparspak, FastLapackInterface
33
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
44
using Test
5-
import Random
5+
import CliqueTrees, Random
66

77
# Try to load BLIS extension
88
try
@@ -205,6 +205,58 @@ end
205205
test_interface(SparspakFactorization(), prob1, prob2)
206206
end
207207

208+
@testset "CliqueTrees Factorization (Float64)" begin
209+
A1 = sparse(A / 1)
210+
b1 = rand(n)
211+
x1 = zero(b)
212+
A2 = sparse(A / 2)
213+
b2 = rand(n)
214+
x2 = zero(b)
215+
216+
prob1 = LinearProblem(A1, b1; u0 = x1)
217+
prob2 = LinearProblem(A2, b2; u0 = x2)
218+
test_interface(CliqueTreesFactorization(), prob1, prob2)
219+
end
220+
221+
@testset "CliqueTrees Factorization (Float64x1)" begin
222+
A1 = sparse(A / 1) .|> Float64x1
223+
b1 = rand(n) .|> Float64x1
224+
x1 = zero(b) .|> Float64x1
225+
A2 = sparse(A / 2) .|> Float64x1
226+
b2 = rand(n) .|> Float64x1
227+
x2 = zero(b) .|> Float64x1
228+
229+
prob1 = LinearProblem(A1, b1; u0 = x1)
230+
prob2 = LinearProblem(A2, b2; u0 = x2)
231+
test_interface(CliqueTreesFactorization(), prob1, prob2)
232+
end
233+
234+
@testset "CliqueTrees Factorization (Float64x2)" begin
235+
A1 = sparse(A / 1) .|> Float64x2
236+
b1 = rand(n) .|> Float64x2
237+
x1 = zero(b) .|> Float64x2
238+
A2 = sparse(A / 2) .|> Float64x2
239+
b2 = rand(n) .|> Float64x2
240+
x2 = zero(b) .|> Float64x2
241+
242+
prob1 = LinearProblem(A1, b1; u0 = x1)
243+
prob2 = LinearProblem(A2, b2; u0 = x2)
244+
test_interface(CliqueTreesFactorization(), prob1, prob2)
245+
end
246+
247+
@testset "CliqueTrees Factorization (Dual64)" begin
248+
A1 = sparse(A / 1) .|> Dual64
249+
b1 = rand(n) .|> Dual64
250+
x1 = zero(b) .|> Dual64
251+
A2 = sparse(A / 2) .|> Dual64
252+
b2 = rand(n) .|> Dual64
253+
x2 = zero(b) .|> Dual64
254+
255+
prob1 = LinearProblem(A1, b1; u0 = x1)
256+
prob2 = LinearProblem(A2, b2; u0 = x2)
257+
test_interface(CliqueTreesFactorization(), prob1, prob2)
258+
end
259+
208260
@testset "FastLAPACK Factorizations" begin
209261
A1 = A / 1
210262
b1 = rand(n)

0 commit comments

Comments
 (0)