Skip to content

Commit a53f644

Browse files
Merge pull request #387 from SciML/mklfactorization_default
Make MKL the default when it's available
2 parents e487d3b + 0a4f96a commit a53f644

File tree

7 files changed

+43
-38
lines changed

7 files changed

+43
-38
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
1616
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1717
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1818
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
1920
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2021
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2122
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
@@ -37,7 +38,6 @@ HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3738
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
3839
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3940
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
40-
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
4141
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4242
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4343

@@ -49,7 +49,6 @@ LinearSolveHYPREExt = "HYPRE"
4949
LinearSolveIterativeSolversExt = "IterativeSolvers"
5050
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
5151
LinearSolveKrylovKitExt = "KrylovKit"
52-
LinearSolveMKLExt = "MKL_jll"
5352
LinearSolveMetalExt = "Metal"
5453
LinearSolvePardisoExt = "Pardiso"
5554

@@ -91,7 +90,6 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
9190
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
9291
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
9392
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
94-
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
9593
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
9694
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
9795
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
@@ -101,4 +99,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
10199
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
102100

103101
[targets]
104-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
102+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff"]

src/LinearSolve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,16 @@ PrecompileTools.@recompile_invalidations begin
4040
import Krylov
4141

4242
using SciMLBase
43+
44+
using MKL_jll
4345
end
4446

4547
using Reexport
4648
@reexport using SciMLBase
4749
using SciMLBase: _unwrap_val
4850

51+
const usemkl = MKL_jll.is_available()
52+
4953
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
5054
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
5155
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
@@ -91,6 +95,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
9195
CholeskyFactorization
9296
NormalCholeskyFactorization
9397
AppleAccelerateLUFactorization
98+
MKLLUFactorization
9499
end
95100

96101
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
@@ -100,6 +105,7 @@ end
100105
include("common.jl")
101106
include("factorization.jl")
102107
include("appleaccelerate.jl")
108+
include("mkl.jl")
103109
include("simplelu.jl")
104110
include("simplegmres.jl")
105111
include("iterative_wrappers.jl")

src/default.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
needs_concrete_A(alg::DefaultLinearSolver) = true
22
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
3-
T13, T14, T15, T16, T17}
3+
T13, T14, T15, T16, T17, T18}
44
LUFactorization::T1
55
QRFactorization::T2
66
DiagonalFactorization::T3
@@ -18,6 +18,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
1818
CholeskyFactorization::T15
1919
NormalCholeskyFactorization::T16
2020
AppleAccelerateLUFactorization::T17
21+
MKLLUFactorization::T18
2122
end
2223

2324
# Legacy fallback
@@ -162,19 +163,24 @@ function defaultalg(A, b, assump::OperatorAssumptions)
162163
DefaultAlgorithmChoice.GenericLUFactorization
163164
elseif VERSION >= v"1.8" && appleaccelerate_isavailable()
164165
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
165-
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
166+
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
167+
(usemkl && length(b) <= 200)) &&
166168
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
167169
eltype(A) <: Union{Float32, Float64})
168170
DefaultAlgorithmChoice.RFLUFactorization
169171
#elseif A === nothing || A isa Matrix
170172
# alg = FastLUFactorization()
173+
elseif usemkl
174+
DefaultAlgorithmChoice.MKLLUFactorization
171175
else
172-
DefaultAlgorithmChoice.GenericLUFactorization
176+
DefaultAlgorithmChoice.LUFactorization
173177
end
174178
elseif __conditioning(assump) === OperatorCondition.VeryIllConditioned
175179
DefaultAlgorithmChoice.QRFactorization
176180
elseif __conditioning(assump) === OperatorCondition.SuperIllConditioned
177181
DefaultAlgorithmChoice.SVDFactorization
182+
elseif usemkl
183+
DefaultAlgorithmChoice.MKLLUFactorization
178184
else
179185
DefaultAlgorithmChoice.LUFactorization
180186
end
@@ -209,6 +215,8 @@ function algchoice_to_alg(alg::Symbol)
209215
LDLtFactorization()
210216
elseif alg === :LUFactorization
211217
LUFactorization()
218+
elseif alg === :MKLLUFactorization
219+
MKLLUFactorization()
212220
elseif alg === :QRFactorization
213221
QRFactorization()
214222
elseif alg === :DiagonalFactorization

src/extension_algs.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -339,16 +339,6 @@ A wrapper over the IterativeSolvers.jl MINRES.
339339
"""
340340
function IterativeSolversJL_MINRES end
341341

342-
"""
343-
```julia
344-
MKLLUFactorization()
345-
```
346-
347-
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
348-
to avoid allocations and does not require libblastrampoline.
349-
"""
350-
struct MKLLUFactorization <: AbstractFactorization end
351-
352342
"""
353343
```julia
354344
MetalLUFactorization()

src/init.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ function __init__()
1212
@require KrylovKit="0b1a1467-8014-51b9-945f-bf0ae24f4b77" begin
1313
include("../ext/LinearSolveKrylovKitExt.jl")
1414
end
15-
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
16-
include("../ext/LinearSolveMKLExt.jl")
17-
end
1815
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
1916
include("../ext/LinearSolveEnzymeExt.jl")
2017
end

ext/LinearSolveMKLExt.jl renamed to src/mkl.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
module LinearSolveMKLExt
1+
"""
2+
```julia
3+
MKLLUFactorization()
4+
```
25
3-
using MKL_jll
4-
using LinearAlgebra: BlasInt, LU
5-
using LinearAlgebra.LAPACK: require_one_based_indexing,
6-
chkfinite, chkstride1,
7-
@blasfunc, chkargsok
8-
using LinearAlgebra
9-
const usemkl = MKL_jll.is_available()
10-
11-
using LinearSolve
12-
using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase
6+
A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace
7+
to avoid allocations and does not require libblastrampoline.
8+
"""
9+
struct MKLLUFactorization <: AbstractFactorization end
1310

1411
function getrf!(A::AbstractMatrix{<:Float64};
1512
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
@@ -104,10 +101,15 @@ end
104101
default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false
105102
default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false
106103

107-
function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
104+
const PREALLOCATED_MKL_LU = begin
105+
A = rand(0, 0)
106+
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
107+
end
108+
109+
function init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr,
108110
maxiters::Int, abstol, reltol, verbose::Bool,
109111
assumptions::OperatorAssumptions)
110-
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), Ref{BlasInt}()
112+
PREALLOCATED_MKL_LU
111113
end
112114

113115
function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
@@ -140,6 +142,4 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
140142
141143
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
142144
=#
143-
end
144-
145-
end
145+
end

test/default_algs.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ solve(prob)
99
prob = LinearProblem(rand(50, 50), rand(50))
1010
solve(prob)
1111

12-
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
13-
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
12+
if LinearSolve.usemkl
13+
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
14+
LinearSolve.DefaultAlgorithmChoice.MKLLUFactorization
15+
else
16+
@test LinearSolve.defaultalg(nothing, zeros(600)).alg ===
17+
LinearSolve.DefaultAlgorithmChoice.LUFactorization
18+
end
19+
1420
prob = LinearProblem(rand(600, 600), rand(600))
1521
solve(prob)
1622

0 commit comments

Comments
 (0)