Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ using LinearSolve
using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
using SciMLBase: AbstractSciMLOperator

function LinearSolve.is_cusparse(A::Union{CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function LinearSolve.is_cusparse(A::Union{CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
function LinearSolve.is_cusparse(A::Union{
CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})

true
end

function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if LinearSolve.cudss_loaded(A)
Expand Down
36 changes: 34 additions & 2 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,31 @@ function LinearSolve.init_cacheval(
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
if is_cusparse(A)
ArrayInterface.lu_instance(A)
else
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(
zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))

end
end

function LinearSolve.init_cacheval(
alg::LUFactorization, A::AbstractSparseArray{T, Int32}, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES}
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
if LinearSolve.is_cusparse(A)
ArrayInterface.lu_instance(A)
else
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(
zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))

end
end

function LinearSolve.init_cacheval(
alg::LUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(A)
end

function LinearSolve.init_cacheval(
Expand All @@ -120,6 +136,14 @@ function LinearSolve.init_cacheval(
PREALLOCATED_UMFPACK
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::AbstractSparseArray{T, Int64}, b, u,
Pl, Pr,
Expand Down Expand Up @@ -191,6 +215,14 @@ function LinearSolve.init_cacheval(
PREALLOCATED_KLU
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int32}, b, u, Pl, Pr,
maxiters::Int, abstol,
Expand Down
1 change: 1 addition & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ end
ALREADY_WARNED_CUDSS = Ref{Bool}(false)
error_no_cudss_lu(A) = nothing
cudss_loaded(A) = false
is_cusparse(A) = false

export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
Expand Down
1 change: 1 addition & 0 deletions test/gpu/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
15 changes: 15 additions & 0 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearSolve, CUDA, LinearAlgebra, SparseArrays, StableRNGs
using CUDA.CUSPARSE, CUDSS
using Test

CUDA.allowscalar(false)
Expand Down Expand Up @@ -91,3 +92,17 @@ prob2 = LinearProblem(transpose(A), b)
sol = solve(prob2, alg; alias = LinearAliasSpecifier(alias_A = false))
@test norm(transpose(A) * sol.u .- b) < 1e-5
end

@testset "CUDSS" begin
T = Float32
n = 100
A_cpu = sprand(T, n, n, 0.05) + I
x_cpu = zeros(T, n)
b_cpu = rand(T, n)

A_gpu_csr = CuSparseMatrixCSR(A_cpu)
b_gpu = CuVector(b_cpu)

prob = LinearProblem(A_gpu_csr, b_gpu)
sol = solve(prob)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Loading