Skip to content
Merged
24 changes: 17 additions & 7 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ end
function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
ext = Base.get_extension(LinearSolve, :LinearSolveSparspakExt)
@show "here2", Tv, Ti, typeof(A)
if assump.issq && ext !== nothing
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.SparspakFactorization)
elseif !assump.issq
Expand Down Expand Up @@ -83,10 +84,14 @@ function LinearSolve.init_cacheval(
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
zerobased = SparseArrays.getcolptr(A)[1] == 0
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
rowvals(A), nonzeros(A)))
if size(A,1) == size(A,2)
A = convert(AbstractMatrix, A)
zerobased = SparseArrays.getcolptr(A)[1] == 0
return SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(size(A)..., getcolptr(A),
rowvals(A), nonzeros(A)))
else
PREALLOCATED_UMFPACK
end
end

function SciMLBase.solve!(
Expand Down Expand Up @@ -141,9 +146,13 @@ function LinearSolve.init_cacheval(
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
if size(A,1) == size(A,2)
A = convert(AbstractMatrix, A)
return KLU.KLUFactorization(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)))
else
PREALLOCATED_KLU
end
end

# TODO: guard this against errors
Expand Down Expand Up @@ -238,6 +247,7 @@ end
function LinearSolve.defaultalg(
A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Ti}
@show "here"
if assump.issq
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KLUFactorization)
Expand Down
25 changes: 15 additions & 10 deletions ext/LinearSolveSparspakExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,23 @@ function LinearSolve.init_cacheval(
end

function LinearSolve.init_cacheval(
::SparspakFactorization, A::AbstractSparseMatrixCSC, b, u, Pl, Pr, maxiters::Int, abstol,
::SparspakFactorization, A::AbstractSparseMatrixCSC{Tv, Ti}, b, u, Pl, Pr, maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
A = convert(AbstractMatrix, A)
if A isa SparseArrays.AbstractSparseArray
return sparspaklu(
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
factorize = false)
verbose::Bool, assumptions::OperatorAssumptions) where {Tv, Ti}

if size(A,1) == size(A,2)
Comment on lines +23 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if size(A,1) == size(A,2)
if size(A, 1) == size(A, 2)

A = convert(AbstractMatrix, A)
if A isa SparseArrays.AbstractSparseArray
return sparspaklu(
SparseMatrixCSC{Tv, Ti}(size(A)..., getcolptr(A), rowvals(A),
nonzeros(A)),
factorize = false)
else
return sparspaklu(SparseMatrixCSC(0, 0, [one(Ti)], Ti[], eltype(A)[]),
factorize = false)
end
else
return sparspaklu(SparseMatrixCSC(0, 0, [1], Int[], eltype(A)[]),
factorize = false)
PREALLOCATED_SPARSEPAK
end
end

Expand Down
28 changes: 28 additions & 0 deletions test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,31 @@ cache.A = [2.0 1.0
sol = solve!(cache)

@test !SciMLBase.successful_retcode(sol.retcode)

## Non-square Sparse Defaults
# https://github.com/SciML/NonlinearSolve.jl/issues/599
A = SparseMatrixCSC{Float64, Int64}([
1.0 0.0
1.0 1.0
])
Comment on lines +150 to +153
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A = SparseMatrixCSC{Float64, Int64}([
1.0 0.0
1.0 1.0
])
A = SparseMatrixCSC{Float64, Int64}([1.0 0.0
1.0 1.0])

b = ones(2)
A2 = hcat(A,A)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A2 = hcat(A,A)
A2 = hcat(A, A)

prob = LinearProblem(A, b)
@test SciMLBase.successful_retcode(solve(prob))

prob2 = LinearProblem(A2, b)
@test SciMLBase.successful_retcode(solve(prob2))

A = SparseMatrixCSC{Float64, Int32}([
1.0 0.0
1.0 1.0
])
Comment on lines +162 to +165
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A = SparseMatrixCSC{Float64, Int32}([
1.0 0.0
1.0 1.0
])
A = SparseMatrixCSC{Float64, Int32}([1.0 0.0
1.0 1.0])

b = ones(2)
A2 = hcat(A,A)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A2 = hcat(A,A)
A2 = hcat(A, A)

prob = LinearProblem(A, b)
@test SciMLBase.successful_retcode(solve(prob))

@info "This test"

prob2 = LinearProblem(A2, b)
@test SciMLBase.successful_retcode(solve(prob2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test SciMLBase.successful_retcode(solve(prob2))
@test SciMLBase.successful_retcode(solve(prob2))

Loading