Skip to content

Commit 683d5c3

Browse files
WIP: Setup MKL direct factorizations
MWE: ```julia using LinearSolve, MKL_jll A = rand(4, 4); b = rand(4); u0 = zeros(4); lp = LinearProblem(A, b); truesol = solve(lp, LUFactorization()) mklsol = solve(lp, MKLLUFactorization()) @test truesol ≈ mklsol ``` The segfault can be reproduced just with the triangular solver. MWE without LinearSolve: ```julia using MKL_jll using LinearAlgebra: BlasInt, LU using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, chkargsok const usemkl = MKL_jll.is_available() function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false) require_one_based_indexing(A) check && chkfinite(A) chkstride1(A) m, n = size(A) lda = max(1,stride(A, 2)) ccall((:dgetrf_, MKL_jll.libmkl_rt), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) chkargsok(info[]) A, ipiv, info[] #Error code is stored in LU factorization type end A = rand(4,4); b = rand(4) getrf!(A) LU(getrf!(A)...) \ b ```
1 parent 98a2292 commit 683d5c3

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2929
[weakdeps]
3030
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3131
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
32+
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
3233
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
3334
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
3435
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
@@ -37,6 +38,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
3738
LinearSolveCUDAExt = "CUDA"
3839
LinearSolveHYPREExt = "HYPRE"
3940
LinearSolveIterativeSolversExt = "IterativeSolvers"
41+
LinearSolveMKLExt = "MKL_jll"
4042
LinearSolveKrylovKitExt = "KrylovKit"
4143
LinearSolvePardisoExt = "Pardiso"
4244

ext/LinearSolveMKLExt.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module LinearSolveMKLExt
2+
3+
using MKL_jll
4+
using LinearAlgebra: BlasInt, LU
5+
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, chkargsok
6+
using LinearAlgebra
7+
const usemkl = MKL_jll.is_available()
8+
9+
using LinearSolve
10+
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
11+
12+
function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false)
13+
require_one_based_indexing(A)
14+
check && chkfinite(A)
15+
chkstride1(A)
16+
m, n = size(A)
17+
lda = max(1,stride(A, 2))
18+
19+
if isempty(ipiv)
20+
ipiv = similar(A, BlasInt, min(size(A,1),size(A,2)))
21+
end
22+
23+
ccall((:dgetrf_, MKL_jll.libmkl_rt), Cvoid,
24+
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
25+
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
26+
m, n, A, lda, ipiv, info)
27+
chkargsok(info[])
28+
A, ipiv, info[] #Error code is stored in LU factorization type
29+
end
30+
31+
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
32+
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
33+
34+
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
35+
maxiters::Int, abstol, reltol, verbose::Bool,
36+
assumptions::OperatorAssumptions)
37+
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
38+
end
39+
40+
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
41+
kwargs...)
42+
A = cache.A
43+
A = convert(AbstractMatrix, A)
44+
if cache.isfresh
45+
cacheval = @get_cacheval(cache, :MKLLUFactorization)
46+
fact = LU(getrf!(A)...)
47+
cache.cacheval = fact
48+
cache.isfresh = false
49+
end
50+
y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b)
51+
SciMLBase.build_linear_solution(alg, y, nothing, cache)
52+
end
53+
54+
end

src/LinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ export HYPREAlgorithm
181181
export CudaOffloadFactorization
182182
export MKLPardisoFactorize, MKLPardisoIterate
183183
export PardisoJL
184+
export MKLLUFactorization
184185

185186
export OperatorAssumptions, OperatorCondition
186187

src/extension_algs.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,13 @@ A wrapper over the IterativeSolvers.jl MINRES.
337337
338338
"""
339339
function IterativeSolversJL_MINRES end
340+
341+
"""
342+
```julia
343+
MKLLUFactorization()
344+
```
345+
346+
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
347+
to avoid allocations and does not require libblastrampoline.
348+
"""
349+
struct MKLLUFactorization <: AbstractFactorization end

0 commit comments

Comments
 (0)