
Commit c0e1e13

Revert "Fix default algorithm for sparse CUDA matrices to LUFactorization"
1 parent 0c0f1e4 commit c0e1e13
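
This revert restores the soft-failure default for sparse CUDA matrices: instead of erroring when CUDSS.jl is not loaded, `defaultalg` for a `CuSparseMatrixCSR` warns once and falls back to `KrylovJL_GMRES`, and the `CuSparseMatrixCSC`-specific methods are removed. A minimal usage sketch of the restored behavior (assumptions: a functional GPU, CUDA.jl and LinearSolve.jl loaded, CUDSS.jl *not* loaded; the problem setup is illustrative, not taken from the commit):

```julia
using LinearSolve, LinearAlgebra, SparseArrays, CUDA, CUDA.CUSPARSE

# Illustrative well-conditioned sparse system, uploaded to the GPU as CSR.
A_cpu = sprand(100, 100, 0.05) + 10.0I
A = CuSparseMatrixCSR(A_cpu)
b = CuArray(rand(100))

prob = LinearProblem(A, b)
# Without CUDSS.jl loaded, this warns once about the missing library and
# defaults to KrylovJL_GMRES rather than erroring.
sol = solve(prob)
```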

File tree

2 files changed: +19, -44 lines

ext/LinearSolveCUDAExt.jl

Lines changed: 19 additions & 43 deletions
```diff
@@ -5,11 +5,9 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
                   DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
                   needs_concrete_A,
                   error_no_cudss_lu, init_cacheval, OperatorAssumptions,
-                  CudaOffloadFactorization, CudaOffloadLUFactorization,
-                  CudaOffloadQRFactorization,
+                  CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
                   CUDAOffload32MixedLUFactorization,
-                  SparspakFactorization, KLUFactorization, UMFPACKFactorization,
-                  LinearVerbosity
+                  SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
 using SciMLBase: AbstractSciMLOperator
 
@@ -25,16 +23,11 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
     if LinearSolve.cudss_loaded(A)
         LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
     else
-        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
-    end
-end
-
-function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
-        assump::OperatorAssumptions{Bool}) where {Tv, Ti}
-    if LinearSolve.cudss_loaded(A)
-        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
-    else
-        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
+        if !LinearSolve.ALREADY_WARNED_CUDSS[]
+            @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
+            LinearSolve.ALREADY_WARNED_CUDSS[] = true
+        end
+        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
     end
 end
 
```
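The fallback branch above fires its warning at most once per session via a module-level flag. A minimal sketch of that warn-once guard (that `ALREADY_WARNED_CUDSS` is a `Ref{Bool}` initialized at load time is an assumption implied by the `[]` indexing in the diff):

```julia
# Warn-once guard pattern, as implied by the diff above.
const ALREADY_WARNED_CUDSS = Ref(false)  # assumed Ref{Bool}, set once per session

function warn_cudss_missing()
    if !ALREADY_WARNED_CUDSS[]
        @warn "CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Falling back to Krylov"
        ALREADY_WARNED_CUDSS[] = true  # suppress repeat warnings
    end
end
```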
```diff
@@ -45,13 +38,6 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
     nothing
 end
 
-function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
-    if !LinearSolve.cudss_loaded(A)
-        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
-    end
-    nothing
-end
-
 function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
         kwargs...)
     if cache.isfresh
@@ -66,15 +52,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(
-        alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     # Check if CUDA is functional before creating CUDA arrays
     if !CUDA.functional()
         return nothing
     end
-
+
     T = eltype(A)
     noUnitT = typeof(zero(T))
     luT = LinearAlgebra.lutype(noUnitT)
@@ -102,7 +87,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
     if !CUDA.functional()
         return nothing
     end
-
+
     qr(CUDA.CuArray(A))
 end
 
@@ -119,42 +104,35 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(
-        alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     qr(CUDA.CuArray(A))
 end
 
 function LinearSolve.init_cacheval(
         ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol,
-        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol,
-        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol,
-        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 # Mixed precision CUDA LU implementation
-function SciMLBase.solve!(
-        cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
+function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
         kwargs...)
     if cache.isfresh
-        fact, A_gpu_f32,
-        b_gpu_f32,
-        u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
         # Compute 32-bit type on demand and convert
         T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
         A_f32 = T32.(cache.A)
@@ -163,14 +141,12 @@ function SciMLBase.solve!(
         cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
         cache.isfresh = false
     end
-    fact, A_gpu_f32,
-    b_gpu_f32,
-    u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
-
+    fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+
     # Compute types on demand for conversions
     T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
     Torig = eltype(cache.u)
-
+
     # Convert b to Float32, solve, then convert back to original precision
     b_f32 = T32.(cache.b)
     copyto!(b_gpu_f32, b_f32)
```
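
For context on the mixed-precision hunks: the code demotes the system to `Float32`, factorizes and solves on the GPU, then promotes the result back to the caller's precision. A standalone sketch of that recipe using plain CUDA.jl calls (a simplified illustration, not LinearSolve's cached implementation):

```julia
using CUDA, LinearAlgebra

# Mixed-precision offload: solve A*x = b in Float32 on the GPU and
# return x in the original Float64 precision.
function mixed_precision_solve(A::Matrix{Float64}, b::Vector{Float64})
    A32 = CuArray(Float32.(A))   # demote and upload
    b32 = CuArray(Float32.(b))
    x32 = lu(A32) \ b32          # LU factorization and triangular solves on the GPU
    return Float64.(Array(x32))  # download and promote back
end
```

The trade-off is the usual one: lower memory traffic and faster GPU factorization in exchange for `Float32`-level accuracy.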

ext/LinearSolveCUDSSExt.jl

Lines changed: 0 additions & 1 deletion
```diff
@@ -4,6 +4,5 @@ using LinearSolve: LinearSolve, cudss_loaded
 using CUDSS
 
 LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSR) = true
-LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSC) = true
 
 end
```
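
The extension works by a trait override: `cudss_loaded` presumably defaults to `false` in LinearSolve itself, and loading CUDSS.jl adds methods returning `true` for the types it supports; after this revert only the CSR method remains. A minimal sketch of the pattern with hypothetical stand-in names (not the actual LinearSolve internals):

```julia
# Trait-based capability check, sketched with hypothetical names.
cudss_available(A) = false  # conservative default in the host package

struct FakeCSR end  # stand-in for CuSparseMatrixCSR

# A package extension, loaded only when the weak dependency is present,
# opts in the types it can actually handle:
cudss_available(::FakeCSR) = true

# Callers branch on the trait to pick an algorithm:
choose_alg(A) = cudss_available(A) ? :LUFactorization : :KrylovJL_GMRES
```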

0 commit comments