Skip to content

Commit 0eda0bb

Browse files
Fix CUDSS dispatches
Fixes #625
1 parent 1d2fa41 commit 0eda0bb

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ using LinearSolve
55
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
66
using SciMLBase: AbstractSciMLOperator
77

8+
function LinearSolve.is_cusparse(A::Union{CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
9+
true
10+
end
11+
812
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
913
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
1014
if LinearSolve.cudss_loaded(A)

ext/LinearSolveSparseArraysExt.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,31 @@ function LinearSolve.init_cacheval(
100100
Pl, Pr,
101101
maxiters::Int, abstol, reltol,
102102
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
103-
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
103+
if is_cusparse(A)
104+
ArrayInterface.lu_instance(A)
105+
else
106+
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
107+
end
104108
end
105109

106110
function LinearSolve.init_cacheval(
107111
alg::LUFactorization, A::AbstractSparseArray{T, Int32}, b, u,
108112
Pl, Pr,
109113
maxiters::Int, abstol, reltol,
110114
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
111-
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
115+
if LinearSolve.is_cusparse(A)
116+
ArrayInterface.lu_instance(A)
117+
else
118+
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
119+
end
120+
end
121+
122+
function LinearSolve.init_cacheval(
123+
alg::LUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
124+
Pl, Pr,
125+
maxiters::Int, abstol, reltol,
126+
verbose::Bool, assumptions::OperatorAssumptions)
127+
ArrayInterface.lu_instance(A)
112128
end
113129

114130
function LinearSolve.init_cacheval(
@@ -120,6 +136,14 @@ function LinearSolve.init_cacheval(
120136
PREALLOCATED_UMFPACK
121137
end
122138

139+
function LinearSolve.init_cacheval(
140+
alg::UMFPACKFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
141+
Pl, Pr,
142+
maxiters::Int, abstol, reltol,
143+
verbose::Bool, assumptions::OperatorAssumptions)
144+
nothing
145+
end
146+
123147
function LinearSolve.init_cacheval(
124148
alg::UMFPACKFactorization, A::AbstractSparseArray{T, Int64}, b, u,
125149
Pl, Pr,
@@ -191,6 +215,14 @@ function LinearSolve.init_cacheval(
191215
PREALLOCATED_KLU
192216
end
193217

218+
function LinearSolve.init_cacheval(
219+
alg::KLUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
220+
Pl, Pr,
221+
maxiters::Int, abstol, reltol,
222+
verbose::Bool, assumptions::OperatorAssumptions)
223+
nothing
224+
end
225+
194226
function LinearSolve.init_cacheval(
195227
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int32}, b, u, Pl, Pr,
196228
maxiters::Int, abstol,

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ end
217217
ALREADY_WARNED_CUDSS = Ref{Bool}(false)
218218
error_no_cudss_lu(A) = nothing
219219
cudss_loaded(A) = false
220+
is_cusparse(A) = false
220221

221222
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
222223
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,

0 commit comments

Comments
 (0)