Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
needs_concrete_A,
error_no_cudss_lu, init_cacheval, OperatorAssumptions,
CudaOffloadFactorization,
CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
SparspakFactorization, KLUFactorization, UMFPACKFactorization
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
using SciMLBase: AbstractSciMLOperator
Expand Down Expand Up @@ -35,6 +35,43 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
nothing
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
        kwargs...)
    # Refactorize on the device only when `A` has changed since the last solve.
    if cache.isfresh
        cache.cacheval = lu(CUDA.CuArray(cache.A))
        cache.isfresh = false
    end
    # Upload `b`, solve on the GPU with the cached LU factorization, and copy
    # the result back to the host-side cache.
    u_gpu = ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))
    u_host = Array(u_gpu)
    cache.u .= u_host
    return SciMLBase.build_linear_solution(alg, u_host, nothing, cache)
end

function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # Seed the cache with a device-side LU factorization of `A`.
    A_gpu = CUDA.CuArray(A)
    return lu(A_gpu)
end

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadQRFactorization;
        kwargs...)
    # Refactorize on the device only when `A` has changed since the last solve.
    if cache.isfresh
        cache.cacheval = qr(CUDA.CuArray(cache.A))
        cache.isfresh = false
    end
    # Upload `b`, solve on the GPU with the cached QR factorization, and copy
    # the result back to the host-side cache.
    u_gpu = ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))
    u_host = Array(u_gpu)
    cache.u .= u_host
    return SciMLBase.build_linear_solution(alg, u_host, nothing, cache)
end

function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    # Seed the cache with a device-side QR factorization of `A`.
    A_gpu = CUDA.CuArray(A)
    return qr(A_gpu)
end

# Keep the deprecated CudaOffloadFactorization working by forwarding to QR
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
if cache.isfresh
Expand Down
6 changes: 3 additions & 3 deletions lib/LinearSolveAutotune/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ function get_gpu_algorithms(; skip_missing_algs::Bool = false)
# CUDA algorithms
if is_cuda_available()
try
push!(gpu_algs, CudaOffloadFactorization())
push!(gpu_names, "CudaOffloadFactorization")
push!(gpu_algs, CudaOffloadLUFactorization())
push!(gpu_names, "CudaOffloadLUFactorization")
catch e
msg = "CUDA hardware detected but CudaOffloadFactorization not available: $e. Load CUDA.jl package."
msg = "CUDA hardware detected but CudaOffloadLUFactorization not available: $e. Load CUDA.jl package."
if skip_missing_algs
@warn msg
else
Expand Down
2 changes: 2 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ export SimpleGMRES

export HYPREAlgorithm
export CudaOffloadFactorization
export CudaOffloadLUFactorization
export CudaOffloadQRFactorization
export MKLPardisoFactorize, MKLPardisoIterate
export PanuaPardisoFactorize, PanuaPardisoIterate
export PardisoJL
Expand Down
51 changes: 49 additions & 2 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,70 @@ struct HYPREAlgorithm <: SciMLLinearSolveAlgorithm
end
end

"""
`CudaOffloadLUFactorization()`

An offloading technique used to GPU-accelerate CPU-based computations using LU factorization.
Requires a sufficiently large `A` to overcome the data transfer costs.

!!! note

    Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
"""
struct CudaOffloadLUFactorization <: AbstractFactorization
    function CudaOffloadLUFactorization()
        # Fail early with an actionable message when the CUDA extension is
        # not loaded, since every method of this algorithm lives there.
        ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
        if ext === nothing
            error("CudaOffloadLUFactorization requires that CUDA is loaded, i.e. `using CUDA`")
        else
            return new()
        end
    end
end

"""
`CudaOffloadQRFactorization()`

An offloading technique used to GPU-accelerate CPU-based computations using QR factorization.
Requires a sufficiently large `A` to overcome the data transfer costs.

!!! note

Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
"""
struct CudaOffloadQRFactorization <: AbstractFactorization
function CudaOffloadQRFactorization()
ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
if ext === nothing
error("CudaOffloadQRFactorization requires that CUDA is loaded, i.e. `using CUDA`")
else
return new()
end
end
end

"""
`CudaOffloadFactorization()`

!!! warning
This algorithm is deprecated. Use `CudaOffloadQRFactorization()` instead.

An offloading technique used to GPU-accelerate CPU-based computations.
Requires a sufficiently large `A` to overcome the data transfer costs.

!!! note

Using this solver requires adding the package CUDA.jl, i.e. `using CUDA`
"""
struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization
struct CudaOffloadFactorization <: AbstractFactorization
function CudaOffloadFactorization()
Base.depwarn("`CudaOffloadFactorization` is deprecated, use `CudaOffloadQRFactorization` instead.", :CudaOffloadFactorization)
ext = Base.get_extension(@__MODULE__, :LinearSolveCUDAExt)
if ext === nothing
error("CudaOffloadFactorization requires that CUDA is loaded, i.e. `using CUDA`")
else
return new{}()
return new()
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function test_interface(alg, prob1, prob2)
return
end

@testset "$alg" for alg in (CudaOffloadFactorization(), NormalCholeskyFactorization())
@testset "$alg" for alg in (CudaOffloadFactorization(), CudaOffloadLUFactorization(), CudaOffloadQRFactorization(), NormalCholeskyFactorization())
test_interface(alg, prob1, prob2)
end

Expand Down
2 changes: 2 additions & 0 deletions test/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),
if !(alg in [
DiagonalFactorization,
CudaOffloadFactorization,
CudaOffloadLUFactorization,
CudaOffloadQRFactorization,
CUSOLVERRFFactorization,
AppleAccelerateLUFactorization,
MetalLUFactorization
Expand Down
Loading