Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
- "pre"
group:
- "Core"
- "DefaultsLoading"
- "LinearSolveHYPRE"
- "LinearSolvePardiso"
- "LinearSolveBandedMatrices"
Expand Down
15 changes: 8 additions & 7 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module LinearSolveSparseArraysExt
using LinearSolve, LinearAlgebra
using SparseArrays
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr
using LinearSolve: BLASELTYPES

# Can't `using KLU` because cannot have a dependency in there without
# requiring the user does `using KLU`
Expand Down Expand Up @@ -39,7 +40,7 @@ function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
end

function LinearSolve.defaultalg(
A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
A::Symmetric{<:BLASELTYPES, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization)
end

Expand Down Expand Up @@ -78,7 +79,7 @@ function LinearSolve.init_cacheval(
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -136,7 +137,7 @@ function LinearSolve.init_cacheval(
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractSparseArray, b, u, Pl, Pr,
alg::KLUFactorization, A::AbstractSparseArray{Float64}, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -186,15 +187,15 @@ function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions) where {T <:
Union{Float32, Float64}}
BLASELTYPES}
PREALLOCATED_CHOLMOD
end

function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray, LinearSolve.GPUArraysCore.AnyGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
A::Union{AbstractSparseArray{T}, LinearSolve.GPUArraysCore.AnyGPUArray,
Symmetric{T, <:AbstractSparseArray{T}}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
LinearSolve.ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
end

Expand Down
2 changes: 1 addition & 1 deletion src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ function algchoice_to_alg(alg::Symbol)
elseif alg === :DirectLdiv!
DirectLdiv!()
elseif alg === :SparspakFactorization
SparspakFactorization()
SparspakFactorization(throwerror = false)
elseif alg === :KLUFactorization
KLUFactorization()
elseif alg === :UMFPACKFactorization
Expand Down
17 changes: 13 additions & 4 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray,
cholesky(A; check = false)
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
function init_cacheval(alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
end
Expand All @@ -333,7 +333,7 @@ function init_cacheval(alg::CholeskyFactorization, A::Matrix{Float64}, b, u, Pl,
end

function init_cacheval(alg::CholeskyFactorization,
A::Union{Diagonal, AbstractSciMLOperator}, b, u, Pl, Pr,
A::Union{Diagonal, AbstractSciMLOperator, AbstractArray}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
nothing
Expand Down Expand Up @@ -1044,8 +1044,17 @@ dispatch to route around standard BLAS routines in the case e.g. of arbitrary-pr
floating point numbers or ForwardDiff.Dual.
This e.g. allows for Automatic Differentiation (AD) of a sparse-matrix solve.
"""
Base.@kwdef struct SparspakFactorization <: AbstractSparseFactorization
reuse_symbolic::Bool = true
struct SparspakFactorization <: AbstractSparseFactorization
reuse_symbolic::Bool

function SparspakFactorization(;reuse_symbolic = true, throwerror = true)
ext = Base.get_extension(@__MODULE__, :LinearSolveSparspakExt)
if throwerror && ext === nothing
error("SparspakFactorization requires that Sparspak is loaded, i.e. `using Sparspak`")
else
new(reuse_symbolic)
end
end
end

function init_cacheval(alg::SparspakFactorization,
Expand Down
34 changes: 34 additions & 0 deletions test/defaults_loading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using SparseArrays
using LinearSolve
using Test

n = 10
dx = 1/n
dx2 = dx^-2
vals = Vector{BigFloat}(undef, 0)
cols = Vector{Int}(undef, 0)
rows = Vector{Int}(undef, 0)
for i in 1:n
if i != 1
push!(vals, dx2)
push!(cols, i-1)
push!(rows, i)
end
push!(vals, -2dx2)
push!(cols, i)
push!(rows, i)
if i != n
push!(vals, dx2)
push!(cols, i+1)
push!(rows, i)
end
end
mat = sparse(rows, cols, vals, n, n)
rhs = big.(zeros(n))
rhs[begin] = rhs[end] = -2
prob = LinearProblem(mat, rhs)
@test_throws ["SparspakFactorization required", "using Sparspak"] sol = solve(prob).u

using Sparspak
sol = solve(prob).u
@test sol isa Vector{BigFloat}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test sol isa Vector{BigFloat}
@test sol isa Vector{BigFloat}

4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ if GROUP == "All" || GROUP == "Enzyme"
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
end

if GROUP == "All" || GROUP == "DefaultsLoading"
@time @safetestset "Enzyme Derivative Rules" include("defaults_loading.jl")
end

if GROUP == "LinearSolveCUDA"
Pkg.activate("gpu")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Expand Down
Loading