diff --git a/Project.toml b/Project.toml
index 1ca75401f..797de9d01 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
 BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
+CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
 CUSOLVERRF = "a8cc9031-bad2-4722-94f5-40deabb4245c"
@@ -55,6 +56,7 @@ LinearSolveAMDGPUExt = "AMDGPU"
 LinearSolveBLISExt = ["blis_jll", "LAPACK_jll"]
 LinearSolveBandedMatricesExt = "BandedMatrices"
 LinearSolveBlockDiagonalsExt = "BlockDiagonals"
+LinearSolveCliqueTreesExt = ["CliqueTrees", "SparseArrays"]
 LinearSolveCUDAExt = "CUDA"
 LinearSolveCUDSSExt = "CUDSS"
 LinearSolveCUSOLVERRFExt = ["CUSOLVERRF", "SparseArrays"]
@@ -83,6 +85,7 @@ CUDA = "5"
 CUDSS = "0.4"
 CUSOLVERRF = "0.2.6"
 ChainRulesCore = "1.22"
+CliqueTrees = "1.11.0"
 ConcreteStructs = "0.2.3"
 DocStringExtensions = "0.9.3"
 EnumX = "1.0.4"
@@ -136,6 +139,7 @@ AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
 BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
+CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
 FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
@@ -163,4 +167,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
+test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote", "RecursiveFactorization", "Sparspak", "CliqueTrees", "FastLapackInterface", "SparseArrays", "ExplicitImports"]
diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md
index ed155d8b3..0d2b31ac4 100644
--- a/docs/src/solvers/solvers.md
+++ b/docs/src/solvers/solvers.md
@@ -177,6 +177,16 @@ UMFPACKFactorization
 SparspakFactorization
 ```
 
+### CliqueTrees.jl
+
+!!! note
+
+    Using this solver requires adding the package CliqueTrees.jl, i.e. `using CliqueTrees`
+
+```@docs
+CliqueTreesFactorization
+```
+
 ### Krylov.jl
 
 ```@docs
diff --git a/ext/LinearSolveCliqueTreesExt.jl b/ext/LinearSolveCliqueTreesExt.jl
new file mode 100644
index 000000000..1a9ddf66c
--- /dev/null
+++ b/ext/LinearSolveCliqueTreesExt.jl
@@ -0,0 +1,73 @@
+module LinearSolveCliqueTreesExt
+
+using CliqueTrees: EliminationAlgorithm, SupernodeType, DEFAULT_ELIMINATION_ALGORITHM,
+    DEFAULT_SUPERNODE_TYPE, symbolic, cholinit, lininit, cholesky!, linsolve!
+using LinearSolve
+using SparseArrays
+
+# Keyword constructor. It is defined in the extension because its defaults and type
+# bounds come from CliqueTrees.
+function LinearSolve.CliqueTreesFactorization(;
+        alg::A = DEFAULT_ELIMINATION_ALGORITHM,
+        snd::S = DEFAULT_SUPERNODE_TYPE,
+        reuse_symbolic::Bool = true
+) where {A <: EliminationAlgorithm, S <: SupernodeType}
+    return CliqueTreesFactorization{A, S}(alg, snd, reuse_symbolic)
+end
+
+# Build the cache tuple `(cholfact, cholwork, linwork)` for `A` from a fresh symbolic
+# factorization. Shared by `init_cacheval` and `solve!` so both construct it identically.
+function _init_factorization(alg::CliqueTreesFactorization, A)
+    symbfact = symbolic(A; alg = alg.alg, snd = alg.snd)
+    cholfact, cholwork = cholinit(A, symbfact)
+    linwork = lininit(1, cholfact)
+    return (cholfact, cholwork, linwork)
+end
+
+function LinearSolve.init_cacheval(
+        alg::CliqueTreesFactorization, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
+    return _init_factorization(alg, A)
+end
+
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::CliqueTreesFactorization; kwargs...)
+    if cache.isfresh
+        A = cache.A
+        # (Re)build the cache when none exists yet or symbolic reuse is disabled.
+        if isnothing(cache.cacheval) || !alg.reuse_symbolic
+            cache.cacheval = _init_factorization(alg, A)
+        end
+
+        cholfact, cholwork, _ = cache.cacheval
+        cholesky!(cholfact, cholwork, A)
+        cache.isfresh = false
+    end
+
+    cholfact, _, linwork = cache.cacheval
+    u = cache.u
+    # The solution is read back from `u`, so seed it with the right-hand side first.
+    linsolve!(copyto!(u, cache.b), linwork, cholfact, Val(false))
+    return SciMLBase.build_linear_solution(alg, u, nothing, cache)
+end
+
+LinearSolve.PrecompileTools.@compile_workload begin
+    # Float64 entries, so the workload compiles the `SparseMatrixCSC{Float64}` code
+    # path instead of specializations for an integer-valued matrix.
+    A = sparse(Float64[
+        3 1 0 0 0 0 0 0
+        1 3 1 0 0 2 0 0
+        0 1 3 1 0 1 2 1
+        0 0 1 3 0 0 0 0
+        0 0 0 0 3 1 1 0
+        0 2 1 0 1 3 0 0
+        0 0 2 0 1 0 3 1
+        0 0 1 0 0 0 1 3
+    ])
+
+    b = rand(8)
+    prob = LinearProblem(A, b)
+    sol = solve(prob, CliqueTreesFactorization())
+end
+
+end
diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index ac2b2d148..5afa07ad9 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -241,7 +241,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
     UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
     SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
     BunchKaufmanFactorization, CHOLMODFactorization, LDLtFactorization,
-    CUSOLVERRFFactorization
+    CUSOLVERRFFactorization, CliqueTreesFactorization
 
 export LinearSolveFunction, DirectLdiv!
 
diff --git a/src/factorization.jl b/src/factorization.jl
index 4b3e946a9..6cd6cd571 100644
--- a/src/factorization.jl
+++ b/src/factorization.jl
@@ -1158,6 +1158,46 @@ function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,
     nothing
 end
 
+## CliqueTreesFactorization is here since it's MIT licensed, not GPL
+
+"""
+    CliqueTreesFactorization(;
+        alg = CliqueTrees.DEFAULT_ELIMINATION_ALGORITHM,
+        snd = CliqueTrees.DEFAULT_SUPERNODE_TYPE,
+        reuse_symbolic = true
+    )
+
+The sparse Cholesky factorization algorithm implemented in CliqueTrees.jl.
+The implementation is pure-Julia and accepts arbitrary numeric types. It is
+somewhat slower than CHOLMOD.
+
+All arguments are keyword arguments. The keyword constructor is provided by the
+CliqueTrees.jl package extension, so it is only available after `using CliqueTrees`.
+
+## Keyword Arguments
+
+  - `alg`: elimination algorithm forwarded to `CliqueTrees.symbolic`.
+  - `snd`: supernode type forwarded to `CliqueTrees.symbolic`.
+  - `reuse_symbolic`: if `true`, the symbolic factorization stored in the cache is
+    reused when the matrix is refactorized; if `false`, it is recomputed every time.
+"""
+struct CliqueTreesFactorization{A, S} <: AbstractSparseFactorization
+    alg::A
+    snd::S
+    reuse_symbolic::Bool
+end
+
+function init_cacheval(::CliqueTreesFactorization,
+        ::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
+    nothing
+end
+
+function init_cacheval(::CliqueTreesFactorization, ::StaticArray, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
+    nothing
+end
+
 for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
     InteractiveUtils.subtypes(AbstractSparseFactorization))
     @eval function init_cacheval(alg::$alg, A::MatrixOperator, b, u, Pl, Pr,
diff --git a/test/basictests.jl b/test/basictests.jl
index c1ff0c5be..8aa7547a9 100644
--- a/test/basictests.jl
+++ b/test/basictests.jl
@@ -2,7 +2,7 @@ using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
 using SciMLOperators, RecursiveFactorization, Sparspak, FastLapackInterface
 using IterativeSolvers, KrylovKit, MKL_jll, KrylovPreconditioners
 using Test
-import Random
+import CliqueTrees, Random
 
 # Try to load BLIS extension
 try
@@ -205,6 +205,22 @@ end
         test_interface(SparspakFactorization(), prob1, prob2)
     end
 
+    # Run the same interface test for each element type.
+    @testset "CliqueTrees Factorization ($name)" for (name, T) in (
+        ("Float64", Float64), ("Float64x1", Float64x1),
+        ("Float64x2", Float64x2), ("Dual64", Dual64))
+        A1 = sparse(A / 1) .|> T
+        b1 = rand(n) .|> T
+        x1 = zero(b) .|> T
+        A2 = sparse(A / 2) .|> T
+        b2 = rand(n) .|> T
+        x2 = zero(b) .|> T
+
+        prob1 = LinearProblem(A1, b1; u0 = x1)
+        prob2 = LinearProblem(A2, b2; u0 = x2)
+        test_interface(CliqueTreesFactorization(), prob1, prob2)
+    end
+
     @testset "FastLAPACK Factorizations" begin
         A1 = A / 1
         b1 = rand(n)