Skip to content

Commit 78757ee

Browse files
Merge pull request #96 from SciML/defaults
much better defaults
2 parents 4e0a46d + 83bcb78 commit 78757ee

File tree

4 files changed

+42
-26
lines changed

4 files changed

+42
-26
lines changed

src/LinearSolve.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,11 @@ include("factorization.jl")
3939
include("iterative_wrappers.jl")
4040
include("preconditioners.jl")
4141
include("default.jl")
42+
include("init.jl")
4243

4344
const IS_OPENBLAS = Ref(true)
4445
isopenblas() = IS_OPENBLAS[]
4546

46-
function __init__()
47-
@static if VERSION < v"1.7beta"
48-
blas = BLAS.vendor()
49-
IS_OPENBLAS[] = blas == :openblas64 || blas == :openblas
50-
else
51-
IS_OPENBLAS[] = occursin("openblas", BLAS.get_config().loaded_libs[1].libname)
52-
end
53-
54-
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
55-
@require Pardiso="46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" include("pardiso.jl")
56-
end
57-
5847
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
5948
RFLUFactorization, UMFPACKFactorization, KLUFactorization
6049
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,

src/cuda.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,26 @@
1+
gpu_or_cpu(x::CUDA.CuArray) = CUDA.CuArray
2+
gpu_or_cpu(x::Transpose{<:Any,<:CUDA.CuArray}) = CUDA.CuArray
3+
gpu_or_cpu(x::Adjoint{<:Any,<:CUDA.CuArray}) = CUDA.CuArray
4+
isgpu(::CUDA.CuArray) = true
5+
isgpu(::Transpose{<:Any,<:CUDA.CuArray}) = true
6+
isgpu(::Adjoint{<:Any,<:CUDA.CuArray}) = true
7+
ifgpufree(x::CUDA.CuArray) = CUDA.unsafe_free!(x)
8+
ifgpufree(x::Transpose{<:Any,<:CUDA.CuArray}) = CUDA.unsafe_free!(x.parent)
9+
ifgpufree(x::Adjoint{<:Any,<:CUDA.CuArray}) = CUDA.unsafe_free!(x.parent)
10+
11+
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
12+
TrackedArray = Tracker.TrackedArray
13+
gpu_or_cpu(x::TrackedArray{<:Any,<:Any,<:CUDA.CuArray}) = CUDA.CuArray
14+
gpu_or_cpu(x::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = CUDA.CuArray
15+
gpu_or_cpu(x::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = CUDA.CuArray
16+
isgpu(::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = true
17+
isgpu(::TrackedArray{<:Any,<:Any,<:CUDA.CuArray}) = true
18+
isgpu(::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = true
19+
ifgpufree(x::TrackedArray{<:Any,<:Any,<:CUDA.CuArray}) = CUDA.unsafe_free!(x.data)
20+
ifgpufree(x::Adjoint{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = CUDA.unsafe_free!((x.data).parent)
21+
ifgpufree(x::Transpose{<:Any,TrackedArray{<:Any,<:Any,<:CUDA.CuArray}}) = CUDA.unsafe_free!((x.data).parent)
22+
end
23+
124
struct GPUOffloadFactorization <: AbstractFactorization end
225

326
function SciMLBase.solve(cache::LinearCache, alg::GPUOffloadFactorization; kwargs...)

src/default.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function defaultalg(A,b)
99
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
1010
# it makes sense according to the benchmarks, which is dependent on
1111
# whether MKL or OpenBLAS is being used
12-
if A === nothing || A isa Matrix
12+
if (A === nothing && !isgpu(b)) || A isa Matrix
1313
if (A === nothing || eltype(A) <: Union{Float32,Float64,ComplexF32,ComplexF64}) &&
1414
ArrayInterface.can_setindex(b) && (length(b) <= 100 ||
1515
(isopenblas() && length(b) <= 500)
@@ -30,18 +30,15 @@ function defaultalg(A,b)
3030

3131
# This catches the cases where a factorization overload could exist
3232
# For example, BlockBandedMatrix
33-
elseif ArrayInterface.isstructured(A)
33+
elseif A !== nothing && ArrayInterface.isstructured(A)
3434
alg = GenericFactorization()
3535

3636
# This catches the case where A is a CuMatrix
3737
# Which does not have LU fully defined
38-
elseif !(A isa AbstractDiffEqOperator)
38+
elseif isgpu(A) || isgpu(b)
3939
alg = QRFactorization(false)
4040

4141
# Not factorizable operator, default to only using A*x
42-
# IterativeSolvers is faster on CPU but not GPU-compatible
43-
elseif cache.u isa Array
44-
alg = IterativeSolversJL_GMRES()
4542
else
4643
alg = KrylovJL_GMRES()
4744
end
@@ -92,15 +89,12 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
9289

9390
# This catches the case where A is a CuMatrix
9491
# Which does not have LU fully defined
95-
elseif !(A isa AbstractDiffEqOperator)
92+
elseif isgpu(A)
9693
alg = QRFactorization(false)
9794
SciMLBase.solve(cache, alg, args...; kwargs...)
9895

9996
# Not factorizable operator, default to only using A*x
10097
# IterativeSolvers is faster on CPU but not GPU-compatible
101-
elseif cache.u isa Array
102-
alg = IterativeSolversJL_GMRES()
103-
SciMLBase.solve(cache, alg, args...; kwargs...)
10498
else
10599
alg = KrylovJL_GMRES()
106100
SciMLBase.solve(cache, alg, args...; kwargs...)
@@ -147,15 +141,12 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
147141

148142
# This catches the case where A is a CuMatrix
149143
# Which does not have LU fully defined
150-
elseif !(A isa AbstractDiffEqOperator)
144+
elseif isgpu(A)
151145
alg = QRFactorization(false)
152146
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
153147

154148
# Not factorizable operator, default to only using A*x
155149
# IterativeSolvers is faster on CPU but not GPU-compatible
156-
elseif u isa Array
157-
alg = IterativeSolversJL_GMRES()
158-
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
159150
else
160151
alg = KrylovJL_GMRES()
161152
init_cacheval(alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)

src/init.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
isgpu(x) = false
2+
ifgpufree(x) = nothing
3+
function __init__()
4+
@static if VERSION < v"1.7beta"
5+
blas = BLAS.vendor()
6+
IS_OPENBLAS[] = blas == :openblas64 || blas == :openblas
7+
else
8+
IS_OPENBLAS[] = occursin("openblas", BLAS.get_config().loaded_libs[1].libname)
9+
end
10+
11+
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
12+
@require Pardiso="46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" include("pardiso.jl")
13+
end

0 commit comments

Comments
 (0)