Skip to content

Commit 6e9bdfa

Browse files
Support Apple Accelerate and improve MKL integration
1 parent 464156c commit 6e9bdfa

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

ext/LinearSolveMKLExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
3939
A = convert(AbstractMatrix, A)
4040
if cache.isfresh
4141
cacheval = @get_cacheval(cache, :MKLLUFactorization)
42-
fact = LU(getrf!(A)...)
42+
fact = LU(getrf!(A)...; ipiv = fact.ipiv)
4343
cache.cacheval = fact
4444
cache.isfresh = false
4545
end

src/LinearSolve.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ using EnumX
2323
using Requires
2424
import InteractiveUtils
2525

26+
using LinearAlgebra: BlasInt, LU
27+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
28+
@blasfunc, chkargsok
29+
2630
import GPUArraysCore
2731
import Preferences
2832

@@ -87,6 +91,7 @@ include("solve_function.jl")
8791
include("default.jl")
8892
include("init.jl")
8993
include("extension_algs.jl")
94+
include("appleaccelerate.jl")
9095
include("deprecated.jl")
9196

9297
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;

src/appleaccelerate.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Location of Apple's Accelerate framework (i.e. vecLib) and of its Info.plist.
# For the time being Accelerate is only used as a BLAS/LAPACK provider.
const libacc = "/System/Library/Frameworks/Accelerate.framework/Accelerate"
const libacc_info_plist = "/System/Library/Frameworks/Accelerate.framework/Versions/Current/Resources/Info.plist"
4+
"""
```julia
AppleAccelerateLUFactorization()
```

A wrapper over Apple's Accelerate library. Calls Accelerate directly, in a way that
pre-allocates workspace to avoid allocations and does not require libblastrampoline.
"""
struct AppleAccelerateLUFactorization <: AbstractFactorization end
14+
15+
# Report whether the Accelerate framework can be opened and exports the
# `NEWLAPACK`/`ILP64`-suffixed `dgemm` symbol. Returns `false` if either the
# library handle or the symbol lookup comes back as `C_NULL`.
function is_new_accelerate_available()
    handle = dlopen_e(libacc)
    handle == C_NULL && return false
    return dlsym_e(handle, "dgemm\$NEWLAPACK\$ILP64") != C_NULL
end
26+
27+
"""
```julia
aa_getrf!(A; ipiv, info, check = false)
```

LU-factorize the `Float64` matrix `A` in place by calling the `dgetrf` routine of
Apple's Accelerate library directly.

  - `ipiv`: pivot buffer; defaults to a fresh `BlasInt` vector of length
    `min(size(A)...)`. A caller-supplied buffer must hold at least that many entries.
  - `info`: `Ref{BlasInt}` that receives the status code written by the routine.
  - `check`: when `true`, `chkfinite(A)` is run before factorizing.

Returns the tuple `(A, ipiv, info[])`. Throws a `DimensionMismatch` if `ipiv` is too
short for `A`.
"""
function aa_getrf!(A::AbstractMatrix{<:Float64};
    ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
    info = Ref{BlasInt}(), check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    # The native routine writes min(m, n) pivots; reject a short buffer up front
    # instead of letting it write past the end of `ipiv`.
    length(ipiv) >= min(m, n) || throw(DimensionMismatch(
        "ipiv has length $(length(ipiv)), but at least $(min(m, n)) entries are required"))
    lda = max(1, stride(A, 2))
    # The argument list (m, n, A, lda, ipiv, info) is that of `dgetrf`, so the
    # symbol called must be the LU factorization routine, not `dgemm`.
    ccall(("dgetrf\$NEWLAPACK\$ILP64", libacc), Cvoid,
        (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
            Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
        m, n, A, lda, ipiv, info)
    chkargsok(info[])
    return A, ipiv, info[] # Error code is stored in LU factorization type
end
40+
41+
# By default, neither `A` nor `b` is aliased for the Accelerate LU algorithm.
function default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any)
    return false
end

function default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any)
    return false
end
43+
44+
# Build the initial cache value for `AppleAccelerateLUFactorization`: the object
# returned by `ArrayInterface.lu_instance` for the matrix form of `A`. All other
# arguments are accepted for interface compatibility and are not used here.
function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
    maxiters::Int, abstol, reltol, verbose::Bool,
    assumptions::OperatorAssumptions)
    Amat = convert(AbstractMatrix, A)
    return ArrayInterface.lu_instance(Amat)
end
49+
50+
# Solve the cached linear system with Accelerate's LU factorization.
#
# When `cache.isfresh` is set, `cache.A` is (re)factorized in place through
# `aa_getrf!`, the resulting `LU` object is stored in `cache.cacheval`, and the
# flag is cleared. The stored factorization is then applied with `ldiv!`,
# writing into `cache.u`, and the result is wrapped by
# `SciMLBase.build_linear_solution`.
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
    kwargs...)
    A = cache.A
    A = convert(AbstractMatrix, A)
    if cache.isfresh
        cacheval = @get_cacheval(cache, :AppleAccelerateLUFactorization)
        # Reuse the pivot vector of the previously cached factorization when it
        # has the right length; otherwise allocate one of the same kind. The
        # keyword belongs to `aa_getrf!` (not to `LU`), and it must come from
        # `cacheval`, since `fact` does not exist yet at this point.
        npiv = min(size(A, 1), size(A, 2))
        ipiv = length(cacheval.ipiv) == npiv ? cacheval.ipiv :
               similar(cacheval.ipiv, npiv)
        fact = LU(aa_getrf!(A; ipiv = ipiv)...)
        cache.cacheval = fact
        cache.isfresh = false
    end
    y = ldiv!(cache.u, @get_cacheval(cache, :AppleAccelerateLUFactorization), cache.b)
    return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

0 commit comments

Comments
 (0)