Skip to content

Pure-Julia Sparse Cholesky #721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
CUSOLVERRF = "a8cc9031-bad2-4722-94f5-40deabb4245c"
Expand All @@ -55,6 +56,7 @@ LinearSolveAMDGPUExt = "AMDGPU"
LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"]
LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCliqueTreesExt = ["CliqueTrees", "SparseArrays"]
LinearSolveCUDAExt = "CUDA"
LinearSolveCUDSSExt = "CUDSS"
LinearSolveCUSOLVERRFExt = ["CUSOLVERRF", "SparseArrays"]
Expand Down Expand Up @@ -83,6 +85,7 @@ CUDA = "5"
CUDSS = "0.4"
CUSOLVERRF = "0.2.6"
ChainRulesCore = "1.22"
CliqueTrees = "1.11.0"
ConcreteStructs = "0.2.3"
DocStringExtensions = "0.9.3"
EnumX = "1.0.4"
Expand Down Expand Up @@ -136,6 +139,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
Expand Down Expand Up @@ -163,4 +167,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
10 changes: 10 additions & 0 deletions docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ UMFPACKFactorization
SparspakFactorization
```

### CliqueTrees.jl

!!! note

Using this solver requires adding the package CliqueTrees.jl, i.e. `using CliqueTrees`

```@docs
CliqueTreesFactorization
```

### Krylov.jl

```@docs
Expand Down
65 changes: 65 additions & 0 deletions ext/LinearSolveCliqueTreesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module LinearSolveCliqueTreesExt

using CliqueTrees: EliminationAlgorithm, SupernodeType, DEFAULT_ELIMINATION_ALGORITHM,
DEFAULT_SUPERNODE_TYPE, symbolic, cholinit, lininit, cholesky!, linsolve!
using LinearSolve
using SparseArrays

function LinearSolve.CliqueTreesFactorization(;
alg::A=DEFAULT_ELIMINATION_ALGORITHM,
snd::S=DEFAULT_SUPERNODE_TYPE,
reuse_symbolic::Bool=true,
) where {A <: EliminationAlgorithm, S <: SupernodeType}
return CliqueTreesFactorization{A, S}(alg, snd, reuse_symbolic)
end

function LinearSolve.init_cacheval(
alg::CliqueTreesFactorization, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol,
reltol, verbose::Bool, assumptions::OperatorAssumptions)
symbfact = symbolic(A; alg=alg.alg, snd=alg.snd)
cholfact, cholwork = cholinit(A, symbfact)
linwork = lininit(1, cholfact)
return (cholfact, cholwork, linwork)
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CliqueTreesFactorization; kwargs...)
A = cache.A
u = cache.u
b = cache.b

if cache.isfresh
if isnothing(cache.cacheval) || !alg.reuse_symbolic
symbfact = symbolic(A; alg=alg.alg, snd=alg.snd)
cholfact, cholwork = cholinit(A, symbfact)
linwork = lininit(1, cholfact)
cache.cacheval = (cholfact, cholwork, linwork)
end

cholfact, cholwork, linwork = cache.cacheval
cholesky!(cholfact, cholwork, A)
cache.isfresh = false
end

cholfact, cholwork, linwork = cache.cacheval
linsolve!(copyto!(u, b), linwork, cholfact, Val(false))
return SciMLBase.build_linear_solution(alg, u, nothing, cache)
end

LinearSolve.PrecompileTools.@compile_workload begin
A = sparse([
3 1 0 0 0 0 0 0
1 3 1 0 0 2 0 0
0 1 3 1 0 1 2 1
0 0 1 3 0 0 0 0
0 0 0 0 3 1 1 0
0 2 1 0 1 3 0 0
0 0 2 0 1 0 3 1
0 0 1 0 0 0 1 3
])

b = rand(8)
prob = LinearProblem(A, b)
sol = solve(prob, CliqueTreesFactorization())
end

end
2 changes: 1 addition & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
BunchKaufmanFactorization, CHOLMODFactorization, LDLtFactorization,
CUSOLVERRFFactorization
CUSOLVERRFFactorization, CliqueTreesFactorization

export LinearSolveFunction, DirectLdiv!

Expand Down
29 changes: 29 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,35 @@ function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
nothing
end

## CliqueTreesFactorization is here since it's MIT licensed, not GPL

"""
CliqueTreesFactorization(
alg = CliqueTrees.DEFAULT_ELIMINATION_ALGORITHM,
snd = CliqueTrees.DEFAULT_SUPERNODE_TYPE,
reuse_symbolic = true,
)

The sparse Cholesky factorization algorithm implemented in CliqueTrees.jl.
The implementation is pure-Julia and accepts arbitrary numeric types. It is
somewhat slower than CHOLMOD.
"""
struct CliqueTreesFactorization{A, S} <: AbstractSparseFactorization
alg::A
snd::S
reuse_symbolic::Bool
end

function init_cacheval(::CliqueTreesFactorization, ::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
InteractiveUtils.subtypes(AbstractSparseFactorization))
@eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,
Expand Down
54 changes: 53 additions & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators, RecursiveFactorization, Sparspak, FastLapackInterface
using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
using Test
import Random
import CliqueTrees, Random

# Try to load BLIS extension
try
Expand Down Expand Up @@ -205,6 +205,58 @@ end
test_interface(SparspakFactorization(), prob1, prob2)
end

@testset "CliqueTrees Factorization (Float64)" begin
A1 = sparse(A / 1)
b1 = rand(n)
x1 = zero(b)
A2 = sparse(A / 2)
b2 = rand(n)
x2 = zero(b)

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
test_interface(CliqueTreesFactorization(), prob1, prob2)
end

@testset "CliqueTrees Factorization (Float64x1)" begin
A1 = sparse(A / 1) .|> Float64x1
b1 = rand(n) .|> Float64x1
x1 = zero(b) .|> Float64x1
A2 = sparse(A / 2) .|> Float64x1
b2 = rand(n) .|> Float64x1
x2 = zero(b) .|> Float64x1

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
test_interface(CliqueTreesFactorization(), prob1, prob2)
end

@testset "CliqueTrees Factorization (Float64x2)" begin
A1 = sparse(A / 1) .|> Float64x2
b1 = rand(n) .|> Float64x2
x1 = zero(b) .|> Float64x2
A2 = sparse(A / 2) .|> Float64x2
b2 = rand(n) .|> Float64x2
x2 = zero(b) .|> Float64x2

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
test_interface(CliqueTreesFactorization(), prob1, prob2)
end

@testset "CliqueTrees Factorization (Dual64)" begin
A1 = sparse(A / 1) .|> Dual64
b1 = rand(n) .|> Dual64
x1 = zero(b) .|> Dual64
A2 = sparse(A / 2) .|> Dual64
b2 = rand(n) .|> Dual64
x2 = zero(b) .|> Dual64

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
test_interface(CliqueTreesFactorization(), prob1, prob2)
end

@testset "FastLAPACK Factorizations" begin
A1 = A / 1
b1 = rand(n)
Expand Down
Loading