diff --git a/ext/LinearSolveCUDAExt.jl b/ext/LinearSolveCUDAExt.jl index 48840c097..9ec8224b3 100644 --- a/ext/LinearSolveCUDAExt.jl +++ b/ext/LinearSolveCUDAExt.jl @@ -3,7 +3,8 @@ module LinearSolveCUDAExt using CUDA using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLinearSolver, DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache, needs_concrete_A, - error_no_cudss_lu, CUDSS_LOADED, init_cacheval + error_no_cudss_lu, init_cacheval, OperatorAssumptions, CudaOffloadFactorization, + SparspakFactorization, KLUFactorization, UMFPACKFactorization using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface using SciMLBase: AbstractSciMLOperator @@ -25,7 +26,7 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b, end function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR) - if !LinearSolve.CUDSS_LOADED[] + if !LinearSolve.cudss_loaded(A) error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.") end nothing