diff --git a/ext/LinearSolveCliqueTreesExt.jl b/ext/LinearSolveCliqueTreesExt.jl index 1a9ddf66c..4c4530baf 100644 --- a/ext/LinearSolveCliqueTreesExt.jl +++ b/ext/LinearSolveCliqueTreesExt.jl @@ -1,22 +1,29 @@ module LinearSolveCliqueTreesExt -using CliqueTrees: EliminationAlgorithm, SupernodeType, DEFAULT_ELIMINATION_ALGORITHM, - DEFAULT_SUPERNODE_TYPE, symbolic, cholinit, lininit, cholesky!, linsolve! +using CliqueTrees: symbolic, cholinit, lininit, cholesky!, linsolve! using LinearSolve using SparseArrays -function LinearSolve.CliqueTreesFactorization(; - alg::A=DEFAULT_ELIMINATION_ALGORITHM, - snd::S=DEFAULT_SUPERNODE_TYPE, - reuse_symbolic::Bool=true, - ) where {A <: EliminationAlgorithm, S <: SupernodeType} - return CliqueTreesFactorization{A, S}(alg, snd, reuse_symbolic) +function _symbolic(A::AbstractMatrix, alg::CliqueTreesFactorization) + return symbolic(A; alg=alg.alg, snd=alg.snd) +end + +function _symbolic(A::AbstractMatrix, alg::CliqueTreesFactorization{Nothing}) + return symbolic(A; snd=alg.snd) +end + +function _symbolic(A::AbstractMatrix, alg::CliqueTreesFactorization{<:Any, Nothing}) + return symbolic(A; alg=alg.alg) +end + +function _symbolic(A::AbstractMatrix, alg::CliqueTreesFactorization{Nothing, Nothing}) + return symbolic(A) end function LinearSolve.init_cacheval( alg::CliqueTreesFactorization, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - symbfact = symbolic(A; alg=alg.alg, snd=alg.snd) + symbfact = _symbolic(A, alg) cholfact, cholwork = cholinit(A, symbfact) linwork = lininit(1, cholfact) return (cholfact, cholwork, linwork) @@ -29,7 +36,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CliqueTreesFactor if cache.isfresh if isnothing(cache.cacheval) || !alg.reuse_symbolic - symbfact = symbolic(A; alg=alg.alg, snd=alg.snd) + symbfact = _symbolic(A, alg) cholfact, cholwork = cholinit(A, symbfact) linwork = lininit(1, cholfact) cache.cacheval = (cholfact, cholwork, linwork) diff --git a/src/factorization.jl b/src/factorization.jl index 6cd6cd571..0a9db3ab6 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -1162,8 +1162,8 @@ end """ CliqueTreesFactorization( - alg = CliqueTrees.DEFAULT_ELIMINATION_ALGORITHM, - snd = CliqueTrees.DEFAULT_SUPERNODE_TYPE, + alg = nothing, + snd = nothing, reuse_symbolic = true, ) @@ -1175,6 +1175,22 @@ struct CliqueTreesFactorization{A, S} <: AbstractSparseFactorization alg::A snd::S reuse_symbolic::Bool + + function CliqueTreesFactorization(; + alg::A = nothing, + snd::S = nothing, + reuse_symbolic = true, + throwerror = true, + ) where {A, S} + + ext = Base.get_extension(@__MODULE__, :LinearSolveCliqueTreesExt) + + if throwerror && isnothing(ext) + error("CliqueTreesFactorization requires that CliqueTrees is loaded, i.e. `using CliqueTrees`") + else + new{A, S}(alg, snd, reuse_symbolic) + end + end end function init_cacheval(::CliqueTreesFactorization, ::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,