Skip to content

Commit e487d3b

Browse files
Merge pull request #386 from SciML/appleaccelerate_default
Make AppleAccelerate the default algorithm when available
2 parents 0f21a9d + 6413b06 commit e487d3b

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
 90  90       SVDFactorization
 91  91       CholeskyFactorization
 92  92       NormalCholeskyFactorization
     93 +     AppleAccelerateLUFactorization
 93  94   end
 94  95
 95  96   struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
@@ -98,6 +99,7 @@ end
 98  99
 99 100   include("common.jl")
100 101   include("factorization.jl")
    102 + include("appleaccelerate.jl")
101 103   include("simplelu.jl")
102 104   include("simplegmres.jl")
103 105   include("iterative_wrappers.jl")
@@ -106,7 +108,6 @@ include("solve_function.jl")
106 108   include("default.jl")
107 109   include("init.jl")
108 110   include("extension_algs.jl")
109     - include("appleaccelerate.jl")
110 111   include("deprecated.jl")
111 112
112 113   @generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;

src/appleaccelerate.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ function aa_getrf!(A::AbstractMatrix{<:Float64};
 38  38       if isempty(ipiv)
 39  39           ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
 40  40       end
 41     -
 42  41       ccall(("dgetrf_", libacc), Cvoid,
 43  42           (Ref{Cint}, Ref{Cint}, Ptr{Float64},
 44  43               Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
@@ -121,11 +120,18 @@ end
121 120   default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
122 121   default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
123 122
    123 + const PREALLOCATED_APPLE_LU = @static if VERSION >= v"1.8"
    124 +     A = rand(0, 0)
    125 +     luinst = ArrayInterface.lu_instance(A)
    126 +     LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
    127 + else
    128 +     nothing
    129 + end
    130 +
124 131   function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
125 132       maxiters::Int, abstol, reltol, verbose::Bool,
126 133       assumptions::OperatorAssumptions)
127     -     luinst = ArrayInterface.lu_instance(convert(AbstractMatrix, A))
128     -     LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
    134 +     PREALLOCATED_APPLE_LU
129 135   end
130 136
131 137   function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;

src/default.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
  1   1   needs_concrete_A(alg::DefaultLinearSolver) = true
  2   2   mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
  3     -     T13, T14, T15, T16}
      3 +     T13, T14, T15, T16, T17}
  4   4       LUFactorization::T1
  5   5       QRFactorization::T2
  6   6       DiagonalFactorization::T3
DiagonalFactorization::T3
@@ -17,6 +17,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
 17  17       SVDFactorization::T14
 18  18       CholeskyFactorization::T15
 19  19       NormalCholeskyFactorization::T16
     20 +     AppleAccelerateLUFactorization::T17
 20  21   end
 21  22
 22  23   # Legacy fallback
@@ -159,6 +160,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
159 160               __conditioning(assump) === OperatorCondition.WellConditioned)
160 161               if length(b) <= 10
161 162                   DefaultAlgorithmChoice.GenericLUFactorization
    163 +             elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
    164 +                 DefaultAlgorithmChoice.AppleAccelerateLUFactorization
162 165               elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
163 166                      (A === nothing ? eltype(b) <: Union{Float32, Float64} :
164 167                       eltype(A) <: Union{Float32, Float64})
@@ -232,6 +235,8 @@ function algchoice_to_alg(alg::Symbol)
232 235           CholeskyFactorization()
233 236       elseif alg === :NormalCholeskyFactorization
234 237           NormalCholeskyFactorization()
    238 +     elseif alg === :AppleAccelerateLUFactorization
    239 +         AppleAccelerateLUFactorization()
235 240       else
236 241           error("Algorithm choice symbol $alg not allowed in the default")
237 242       end

0 commit comments

Comments
 (0)