Skip to content

Commit aabe2f2

Browse files
Merge pull request #355 from SciML/accelerate
Support Apple Accelerate and improve MKL integration
2 parents 464156c + 6d5aeb4 commit aabe2f2

File tree

6 files changed

+123
-5
lines changed

6 files changed

+123
-5
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1212
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1313
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1414
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
15+
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1718
Preferences = "21216c6a-2e73-6563-6e65-726566657250"

ext/LinearSolveMKLExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz
1616
chkstride1(A)
1717
m, n = size(A)
1818
lda = max(1,stride(A, 2))
19+
if isempty(ipiv)
20+
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
21+
end
1922
ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
2023
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
2124
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
2225
m, n, A, lda, ipiv, info)
2326
chkargsok(info[])
24-
A, ipiv, info[] #Error code is stored in LU factorization type
27+
A, ipiv, info[], info #Error code is stored in LU factorization type
2528
end
2629

2730
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
@@ -30,7 +33,7 @@ default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
3033
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
3134
maxiters::Int, abstol, reltol, verbose::Bool,
3235
assumptions::OperatorAssumptions)
33-
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
36+
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
3437
end
3538

3639
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
@@ -39,11 +42,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
3942
A = convert(AbstractMatrix, A)
4043
if cache.isfresh
4144
cacheval = @get_cacheval(cache, :MKLLUFactorization)
42-
fact = LU(getrf!(A)...)
45+
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
46+
fact = LU(res[1:3]...), res[4]
4347
cache.cacheval = fact
4448
cache.isfresh = false
4549
end
46-
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
50+
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b)
4751
SciMLBase.build_linear_solution(alg, y, nothing, cache)
4852
end
4953

src/LinearSolve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ using EnumX
2323
using Requires
2424
import InteractiveUtils
2525

26+
using LinearAlgebra: BlasInt, LU
27+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
28+
@blasfunc, chkargsok
29+
2630
import GPUArraysCore
2731
import Preferences
2832

@@ -87,6 +91,7 @@ include("solve_function.jl")
8791
include("default.jl")
8892
include("init.jl")
8993
include("extension_algs.jl")
94+
include("appleaccelerate.jl")
9095
include("deprecated.jl")
9196

9297
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization;
@@ -185,6 +190,7 @@ export CudaOffloadFactorization
185190
export MKLPardisoFactorize, MKLPardisoIterate
186191
export PardisoJL
187192
export MKLLUFactorization
193+
export AppleAccelerateLUFactorization
188194

189195
export OperatorAssumptions, OperatorCondition
190196

src/appleaccelerate.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
using LinearAlgebra
2+
using Libdl
3+
4+
# For now, only use BLAS from Accelerate (that is to say, vecLib)
5+
# Path of the Accelerate framework binary; used as the library argument of the `ccall`s below.
const libacc = "/System/Library/Frameworks/Accelerate.framework/Accelerate"
6+
7+
"""
8+
```julia
9+
AppleAccelerateLUFactorization()
10+
```
11+
12+
A wrapper over Apple's Accelerate Library. Calls Accelerate directly in a way that pre-allocates workspace
13+
to avoid allocations and does not require libblastrampoline.
14+
"""
15+
struct AppleAccelerateLUFactorization <: AbstractFactorization end # field-less tag type; the Accelerate methods in this file dispatch on it
16+
17+
"""
    appleaccelerate_isavailable()

Return `true` when the library at `libacc` can be opened and it exposes the
`dgetrf_` symbol, and `false` otherwise.
"""
function appleaccelerate_isavailable()
    handle = Libdl.dlopen_e(libacc)
    # A null handle means the framework could not be opened at all.
    handle == C_NULL && return false
    # The framework is only usable here if it provides the LAPACK LU routine.
    return dlsym_e(handle, "dgetrf_") != C_NULL
end
28+
29+
"""
    aa_getrf!(A; ipiv, info, check = false)

Compute the LU factorization of `A` in place by calling `dgetrf_` from the
library at `libacc`.

# Keywords
- `ipiv`: pivot buffer with `Cint` elements. A fresh buffer of length
  `min(size(A, 1), size(A, 2))` is allocated when the supplied one is empty or
  too short to hold every pivot.
- `info`: `Ref{Cint}` that receives the LAPACK status code.
- `check`: when `true`, `chkfinite(A)` is run before factorizing.

# Returns
The tuple `(A, ipiv, BlasInt(info[]), info)`.

# Throws
`ArgumentError` if `dgetrf_` reports a negative `info`.
"""
function aa_getrf!(A::AbstractMatrix{<:Float64};
        ipiv = similar(A, Cint, min(size(A, 1), size(A, 2))),
        info = Ref{Cint}(), check = false)
    require_one_based_indexing(A)
    check && chkfinite(A)
    chkstride1(A)
    m, n = size(A)
    lda = max(1, stride(A, 2))
    npiv = min(m, n)
    # dgetrf_ writes `min(m, n)` pivot indices. A shorter buffer (e.g. the empty
    # placeholder built by `init_cacheval`, or one sized for a smaller matrix)
    # would be written past its end, so replace it instead of passing it on.
    if isempty(ipiv) || length(ipiv) < npiv
        ipiv = similar(A, Cint, npiv)
    end

    ccall(("dgetrf_", libacc), Cvoid,
        (Ref{Cint}, Ref{Cint}, Ptr{Float64},
            Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
        m, n, A, lda, ipiv, info)
    info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_"))
    # A positive `info` is not thrown here; it is returned so the caller can
    # store it in the LU factorization object.
    return A, ipiv, BlasInt(info[]), info
end
46+
47+
"""
    aa_getrs!(trans, A, ipiv, B; info = Ref{Cint}())

Solve the linear system described by the LU factors `A` and pivots `ipiv` for
the right-hand side(s) `B` by calling `dgetrs_` from the library at `libacc`.
`B` is overwritten and returned.

# Throws
- `DimensionMismatch` if `size(B, 1)` or `length(ipiv)` differs from the order
  of the square matrix `A`.
- Whatever `LinearAlgebra.LAPACK.chklapackerror` raises for the returned `info`.
"""
function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64},
        ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64};
        info = Ref{Cint}())
    require_one_based_indexing(A, ipiv, B)
    LinearAlgebra.LAPACK.chktrans(trans)
    chkstride1(A, B, ipiv)
    n = LinearAlgebra.checksquare(A)
    n == size(B, 1) ||
        throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))
    n == length(ipiv) ||
        throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    # The trailing `Clong` argument (value 1) is the length of the `trans` character.
    ccall(("dgetrs_", libacc), Cvoid,
        (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
            Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
        trans, n, nrhs, A, lda, ipiv, B, ldb, info, 1)
    LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
    return B
end
66+
67+
# Both aliasing defaults are `false` for `AppleAccelerateLUFactorization`,
# regardless of the two remaining arguments.
# NOTE(review): presumably this means the solver works on copies of `A` and `b`
# unless the user opts in; confirm against the callers of these functions.
function default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any)
    return false
end
function default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any)
    return false
end
69+
70+
# Build the initial cache value for `AppleAccelerateLUFactorization`: a tuple of
# an `LU` object and a `Ref{Cint}` status slot. The `LU` reuses the `factors`
# and `info` of `ArrayInterface.lu_instance` but carries an empty `Cint` pivot
# vector in place of that instance's own pivots.
function LinearSolve.init_cacheval(alg::AppleAccelerateLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    template = ArrayInterface.lu_instance(convert(AbstractMatrix, A))
    empty_pivots = similar(A, Cint, 0)
    placeholder = LU(template.factors, empty_pivots, template.info)
    return placeholder, Ref{Cint}()
end
76+
77+
# Solve the cached linear problem with the Accelerate LU routines.
#
# When `cache.isfresh` is set, the matrix is factorized with `aa_getrf!` (reusing
# the pivot vector and `info` reference held in the cache value) and the result
# is stored back in `cache.cacheval` as `(LU, info)`. The right-hand side is then
# solved with `aa_getrs!` into `cache.u`.
#
# Every path returns the object produced by `SciMLBase.build_linear_solution`;
# no branch returns the raw solution array.
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorization;
        kwargs...)
    A = convert(AbstractMatrix, cache.A)
    if cache.isfresh
        cacheval = @get_cacheval(cache, :AppleAccelerateLUFactorization)
        res = aa_getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
        # res is (factors, ipiv, info value, info reference)
        fact = LU(res[1:3]...), res[4]
        cache.cacheval = fact
        cache.isfresh = false
    end

    # Separate name for the factorization so `A` keeps a single type.
    lufact, info = @get_cacheval(cache, :AppleAccelerateLUFactorization)
    LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
    m, n = size(lufact, 1), size(lufact, 2)
    if m > n
        # Solve on a copy of `b`, then keep only the first `n` entries in `u`.
        Bc = copy(cache.b)
        aa_getrs!('N', lufact.factors, lufact.ipiv, Bc; info)
        copyto!(cache.u, 1, Bc, 1, n)
    else
        copyto!(cache.u, cache.b)
        aa_getrs!('N', lufact.factors, lufact.ipiv, cache.u; info)
    end

    return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

test/basictests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ end
213213
test_interface(alg, prob1, prob2)
214214
end
215215
end
216+
if LinearSolve.appleaccelerate_isavailable()
217+
test_interface(AppleAccelerateLUFactorization(), prob1, prob2)
218+
end
216219
end
217220

218221
@testset "Generic Factorizations" begin

test/resolve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test
22

33
for alg in subtypes(LinearSolve.AbstractFactorization)
44
@show alg
5-
if !(alg in [DiagonalFactorization, CudaOffloadFactorization])
5+
if !(alg in [DiagonalFactorization, CudaOffloadFactorization, AppleAccelerateLUFactorization]) &&
6+
(!(alg == AppleAccelerateLUFactorization) || LinearSolve.appleaccelerate_isavailable())
7+
68
A = [1.0 2.0; 3.0 4.0]
79
alg in [KLUFactorization, UMFPACKFactorization, SparspakFactorization] &&
810
(A = sparse(A))

0 commit comments

Comments
 (0)