Skip to content

Commit 1ee467b

Browse files
Merge pull request #349 from SciML/mkl
Setup MKL direct factorizations
2 parents 98a2292 + aaf64d3 commit 1ee467b

File tree

5 files changed

+70
-2
lines changed

5 files changed

+70
-2
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2929
[weakdeps]
3030
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3131
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
32+
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
3233
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
3334
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
3435
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
@@ -37,6 +38,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
3738
LinearSolveCUDAExt = "CUDA"
3839
LinearSolveHYPREExt = "HYPRE"
3940
LinearSolveIterativeSolversExt = "IterativeSolvers"
41+
LinearSolveMKLExt = "MKL_jll"
4042
LinearSolveKrylovKitExt = "KrylovKit"
4143
LinearSolvePardisoExt = "Pardiso"
4244

@@ -70,6 +72,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7072
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
7173
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7274
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
75+
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
7376
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
7477
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
7578
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -78,4 +81,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7881
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7982

8083
[targets]
81-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"]
84+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"]

ext/LinearSolveMKLExt.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
module LinearSolveMKLExt
2+
3+
using MKL_jll
4+
using LinearAlgebra: BlasInt, LU
5+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
6+
@blasfunc, chkargsok
7+
using LinearAlgebra
8+
const usemkl = MKL_jll.is_available()
9+
10+
using LinearSolve
11+
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
12+
13+
function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false)
14+
require_one_based_indexing(A)
15+
check && chkfinite(A)
16+
chkstride1(A)
17+
m, n = size(A)
18+
lda = max(1,stride(A, 2))
19+
ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid,
20+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
21+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
22+
m, n, A, lda, ipiv, info)
23+
chkargsok(info[])
24+
A, ipiv, info[] #Error code is stored in LU factorization type
25+
end
26+
27+
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
28+
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
29+
30+
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
31+
maxiters::Int, abstol, reltol, verbose::Bool,
32+
assumptions::OperatorAssumptions)
33+
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
34+
end
35+
36+
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
37+
kwargs...)
38+
A = cache.A
39+
A = convert(AbstractMatrix, A)
40+
if cache.isfresh
41+
cacheval = @get_cacheval(cache, :MKLLUFactorization)
42+
fact = LU(getrf!(A)...)
43+
cache.cacheval = fact
44+
cache.isfresh = false
45+
end
46+
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
47+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
48+
end
49+
50+
end

src/LinearSolve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ end
124124
@require KrylovKit="0b1a1467-8014-51b9-945f-bf0ae24f4b77" begin
125125
include("../ext/LinearSolveKrylovKitExt.jl")
126126
end
127+
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
128+
include("../ext/LinearSolveMKLExt.jl")
129+
end
127130
end
128131
end
129132

@@ -181,6 +184,7 @@ export HYPREAlgorithm
181184
export CudaOffloadFactorization
182185
export MKLPardisoFactorize, MKLPardisoIterate
183186
export PardisoJL
187+
export MKLLUFactorization
184188

185189
export OperatorAssumptions, OperatorCondition
186190

src/extension_algs.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,13 @@ A wrapper over the IterativeSolvers.jl MINRES.
337337
338338
"""
339339
function IterativeSolversJL_MINRES end
340+
341+
"""
342+
```julia
343+
MKLLUFactorization()
344+
```
345+
346+
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
347+
to avoid allocations and does not require libblastrampoline.
348+
"""
349+
struct MKLLUFactorization <: AbstractFactorization end

test/basictests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
22
using SciMLOperators
3-
using IterativeSolvers, KrylovKit
3+
using IterativeSolvers, KrylovKit, MKL_jll
44
using Test
55
import Random
66

@@ -207,6 +207,7 @@ end
207207
QRFactorization(),
208208
SVDFactorization(),
209209
RFLUFactorization(),
210+
MKLLUFactorization(),
210211
LinearSolve.defaultalg(prob1.A, prob1.b))
211212
@testset "$alg" begin
212213
test_interface(alg, prob1, prob2)

0 commit comments

Comments
 (0)