Skip to content

Commit 972c1bc

Browse files
Make AppleAccelerate the default algorithm when available
The benchmarks show that it's just so much better, and there's no cost to doing this, so we should just always use AppleAccelerate on any machine where it exists.
1 parent 0f21a9d commit 972c1bc

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

src/LinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
9090
SVDFactorization
9191
CholeskyFactorization
9292
NormalCholeskyFactorization
93+
AppleAccelerateLUFactorization
9394
end
9495

9596
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
@@ -98,6 +99,7 @@ end
9899

99100
include("common.jl")
100101
include("factorization.jl")
102+
include("appleaccelerate.jl")
101103
include("simplelu.jl")
102104
include("simplegmres.jl")
103105
include("iterative_wrappers.jl")
@@ -106,7 +108,6 @@ include("solve_function.jl")
106108
include("default.jl")
107109
include("init.jl")
108110
include("extension_algs.jl")
109-
include("appleaccelerate.jl")
110111
include("deprecated.jl")
111112

112113
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;

src/appleaccelerate.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ function aa_getrf!(A::AbstractMatrix{<:Float64};
3838
if isempty(ipiv)
3939
ipiv = similar(A, Cint, min(size(A, 1), size(A, 2)))
4040
end
41-
4241
ccall(("dgetrf_", libacc), Cvoid,
4342
(Ref{Cint}, Ref{Cint}, Ptr{Float64},
4443
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
@@ -121,11 +120,16 @@ end
121120
default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
122121
default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
123122

123+
const PREALLOCATED_APPLE_LU = begin
124+
A = rand(0, 0)
125+
luinst = ArrayInterface.lu_instance(A)
126+
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
127+
end
128+
124129
function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
125130
maxiters::Int, abstol, reltol, verbose::Bool,
126131
assumptions::OperatorAssumptions)
127-
luinst = ArrayInterface.lu_instance(convert(AbstractMatrix, A))
128-
LU(luinst.factors, similar(A, Cint, 0), luinst.info), Ref{Cint}()
132+
PREALLOCATED_APPLE_LU
129133
end
130134

131135
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;

src/default.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16}
3+
T13, T14, T15, T16, T17}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -17,6 +17,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
1717
SVDFactorization::T14
1818
CholeskyFactorization::T15
1919
NormalCholeskyFactorization::T16
20+
AppleAccelerateLUFactorization::T17
2021
end
2122

2223
# Legacy fallback
@@ -157,7 +158,9 @@ function defaultalg(A, b, assump::OperatorAssumptions)
157158
ArrayInterface.can_setindex(b) &&
158159
(__conditioning(assump) === OperatorCondition.IllConditioned ||
159160
__conditioning(assump) === OperatorCondition.WellConditioned)
160-
if length(b) <= 10
161+
if appleaccelerate_isavailable()
162+
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
163+
elseif length(b) <= 10
161164
DefaultAlgorithmChoice.GenericLUFactorization
162165
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
163166
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
@@ -232,6 +235,8 @@ function algchoice_to_alg(alg::Symbol)
232235
CholeskyFactorization()
233236
elseif alg === :NormalCholeskyFactorization
234237
NormalCholeskyFactorization()
238+
elseif alg === :AppleAccelerateLUFactorization
239+
AppleAccelerateLUFactorization()
235240
else
236241
error("Algorithm choice symbol $alg not allowed in the default")
237242
end

0 commit comments

Comments
 (0)