Skip to content

Commit 5c41c43

Browse files
Make MKL the default when it's available
The benchmarks have pretty conclusively shown that MKL's LU factorization is just so much better than OpenBLAS that we should effectively always use it. What this does is make MKL_jll into a dependency of LinearSolve.jl and then uses the direct calls to the binary as part of the default algorithm when it's available (it won't be available on systems where MKL does not exist, like M2 macbooks). This uses the direct calls instead of LibBLASTrampoline and thus does not effect the user's global state, thus only being a local change that simply accelerates packages using LinearSolve (i.e. all of SciML).
1 parent e487d3b commit 5c41c43

File tree

4 files changed

+18
-23
lines changed

4 files changed

+18
-23
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1616
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1717
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1818
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
1920
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2021
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2122
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
@@ -37,7 +38,6 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3738
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
3839
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3940
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
40-
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
4141
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4242
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4343

@@ -49,7 +49,6 @@ LinearSolveHYPREExt = "HYPRE"
4949
LinearSolveIterativeSolversExt = "IterativeSolvers"
5050
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
5151
LinearSolveKrylovKitExt = "KrylovKit"
52-
LinearSolveMKLExt = "MKL_jll"
5352
LinearSolveMetalExt = "Metal"
5453
LinearSolvePardisoExt = "Pardiso"
5554

@@ -91,7 +90,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
9190
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
9291
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
9392
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
94-
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
9593
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
9694
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
9795
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
@@ -101,4 +99,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
10199
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
102100

103101
[targets]
104-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
102+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]

src/LinearSolve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,16 @@ PrecompileTools.@recompile_invalidations begin
4040
import Krylov
4141

4242
using SciMLBase
43+
44+
using MKL_jll
4345
end
4446

4547
using Reexport
4648
@reexport using SciMLBase
4749
using SciMLBase: _unwrap_val
4850

51+
const usemkl = MKL_jll.is_available()
52+
4953
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
5054
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
5155
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end

src/default.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17}
3+
T13, T14, T15, T16, T17, T18}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -18,6 +18,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
1818
CholeskyFactorization::T15
1919
NormalCholeskyFactorization::T16
2020
AppleAccelerateLUFactorization::T17
21+
MKLLUFactorization::T18
2122
end
2223

2324
# Legacy fallback
@@ -162,19 +163,24 @@ function defaultalg(A, b, assump::OperatorAssumptions)
162163
DefaultAlgorithmChoice.GenericLUFactorization
163164
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
164165
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
165-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
166+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
167+
(usemkl && length(b) <= 200)) &&
166168
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
167169
eltype(A) <: Union{Float32, Float64})
168170
DefaultAlgorithmChoice.RFLUFactorization
169171
#elseif A === nothing || A isa Matrix
170172
# alg = FastLUFactorization()
173+
elseif usemkl
174+
DefaultAlgorithmChoice.MKLLUFactorization
171175
else
172-
DefaultAlgorithmChoice.GenericLUFactorization
176+
DefaultAlgorithmChoice.LUFactorization
173177
end
174178
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
175179
DefaultAlgorithmChoice.QRFactorization
176180
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
177181
DefaultAlgorithmChoice.SVDFactorization
182+
elseif usemkl
183+
DefaultAlgorithmChoice.MKLLUFactorization
178184
else
179185
DefaultAlgorithmChoice.LUFactorization
180186
end
@@ -209,6 +215,8 @@ function algchoice_to_alg(alg::Symbol)
209215
LDLtFactorization()
210216
elseif alg === :LUFactorization
211217
LUFactorization()
218+
elseif alg === :MKLLUFactorization
219+
MKLLUFactorization()
212220
elseif alg === :QRFactorization
213221
QRFactorization()
214222
elseif alg === :DiagonalFactorization

ext/LinearSolveMKLExt.jl renamed to src/mkl.jl

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
module LinearSolveMKLExt
2-
3-
using MKL_jll
4-
using LinearAlgebra: BlasInt, LU
5-
using LinearAlgebra.LAPACK: require_one_based_indexing,
6-
chkfinite, chkstride1,
7-
@blasfunc, chkargsok
8-
using LinearAlgebra
9-
const usemkl = MKL_jll.is_available()
10-
11-
using LinearSolve
12-
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
13-
141
function getrf!(A::AbstractMatrix{<:Float64};
152
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
163
info = Ref{BlasInt}(),
@@ -140,6 +127,4 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
140127
141128
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
142129
=#
143-
end
144-
145-
end
130+
end

0 commit comments

Comments
 (0)