Skip to content

Commit 137eed4

Browse files
Merge pull request #37 from vpuri3/vp-pardiso
Pardiso
2 parents e99d699 + 2d3822f commit 137eed4

File tree

5 files changed

+142
-4
lines changed

5 files changed

+142
-4
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
1313
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
14+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1415
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1516
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -32,6 +33,7 @@ julia = "1.6"
3233
DiffEqProblemLibrary = "a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9"
3334
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
3435
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36+
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
3537

3638
[targets]
37-
test = ["Test", "OrdinaryDiffEq", "DiffEqProblemLibrary"]
39+
test = ["Test", "OrdinaryDiffEq", "DiffEqProblemLibrary", "Pardiso"]

src/LinearSolve.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ using SparseArrays
1010
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
1111
using Setfield
1212
using UnPack
13+
using Requires
1314

1415
# wrap
1516
import Krylov
16-
import KrylovKit
17+
import KrylovKit # TODO
1718
import IterativeSolvers
1819

1920
using Reexport
@@ -38,12 +39,13 @@ function __init__()
3839
else
3940
IS_OPENBLAS[] = occursin("openblas", BLAS.get_config().loaded_libs[1].libname)
4041
end
42+
43+
@require Pardiso="46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" include("pardiso.jl")
4144
end
4245

4346
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
4447
RFLUFactorizaation
45-
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB,
46-
KrylovJL_MINRES,
48+
export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
4749
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
4850
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES
4951
export DefaultLinSolve

src/pardiso.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
2+
## Pardiso
3+
4+
import Pardiso
5+
6+
export PardisoJL, MKLPardisoFactorize, MKLPardisoIterate
7+
8+
Base.@kwdef struct PardisoJL <: SciMLLinearSolveAlgorithm
9+
nprocs::Union{Int, Nothing} = nothing
10+
solver_type::Union{Int, Pardiso.Solver, Nothing} = nothing
11+
matrix_type::Union{Int, Pardiso.MatrixType, Nothing} = nothing
12+
fact_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
13+
solve_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
14+
release_phase::Union{Int, Nothing} = nothing
15+
iparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
16+
dparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
17+
end
18+
19+
MKLPardisoFactorize(;kwargs...) = PardisoJL(;fact_phase=Pardiso.NUM_FACT,
20+
solve_phase=Pardiso.SOLVE_ITERATIVE_REFINE,
21+
kwargs...)
22+
MKLPardisoIterate(;kwargs...) = PardisoJL(;solve_phase=Pardiso.NUM_FACT_SOLVE_REFINE,
23+
kwargs...)
24+
25+
# TODO schur complement functionality
26+
27+
function init_cacheval(alg::PardisoJL, cache::LinearCache)
28+
@unpack nprocs, solver_type, matrix_type, fact_phase, solve_phase, iparm, dparm = alg
29+
@unpack A, b, u = cache
30+
31+
if A isa DiffEqArrayOperator
32+
A = A.A
33+
end
34+
35+
solver =
36+
if Pardiso.PARDISO_LOADED[]
37+
solver = Pardiso.PardisoSolver()
38+
solver_type !== nothing && Pardiso.set_solver!(solver, solver_type)
39+
40+
solver
41+
else
42+
solver = Pardiso.MKLPardisoSolver()
43+
nprocs !== nothing && Pardiso.set_nprocs!(solver, nprocs)
44+
45+
solver
46+
end
47+
48+
Pardiso.pardisoinit(solver) # default initialization
49+
50+
matrix_type !== nothing && Pardiso.set_matrixtype!(solver, matrix_type)
51+
cache.verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
52+
53+
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
54+
if iparm !== nothing
55+
for i in length(iparm)
56+
Pardiso.set_iparm!(solver, iparm[i]...)
57+
end
58+
end
59+
60+
if dparm !== nothing
61+
for i in length(dparm)
62+
Pardiso.set_dparm!(solver, dparm[i]...)
63+
end
64+
end
65+
66+
if (fact_phase !== nothing) | (solve_phase !== nothing)
67+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
68+
Pardiso.pardiso(solver, u, A, b)
69+
end
70+
71+
if fact_phase !== nothing
72+
Pardiso.set_phase!(solver, fact_phase)
73+
Pardiso.pardiso(solver, u, A, b)
74+
end
75+
76+
return solver
77+
end
78+
79+
function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
80+
@unpack A, b, u = cache
81+
if A isa DiffEqArrayOperator
82+
A = A.A
83+
end
84+
85+
if cache.isfresh
86+
solver = init_cacheval(alg, cache)
87+
cache = set_cacheval(cache, solver)
88+
end
89+
90+
alg.solve_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.solve_phase)
91+
Pardiso.pardiso(cache.cacheval, u, A, b)
92+
alg.release_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.release_phase)
93+
94+
return cache.u
95+
end

src/wrappers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,4 @@ function SciMLBase.solve(cache::LinearCache, alg::IterativeSolversJL; kwargs...)
297297

298298
return cache.u
299299
end
300+

test/runtests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,44 @@ end
126126
end
127127
end
128128

129+
@testset "PardisoJL" begin
130+
@test_throws UndefVarError alg = PardisoJL()
131+
132+
using Pardiso, SparseArrays
133+
134+
A1 = sparse([ 1. 0 -2 3
135+
0 5 1 2
136+
-2 1 4 -7
137+
3 2 -7 5 ])
138+
b1 = rand(4)
139+
prob1 = LinearProblem(A1, b1)
140+
141+
lambda = 3
142+
e = ones(n)
143+
e2 = ones(n-1)
144+
A2 = spdiagm(-1 => im*e2, 0 => lambda*e, 1 => -im*e2)
145+
b2 = rand(n) + im * zeros(n)
146+
147+
prob2 = LinearProblem(A2, b2)
148+
149+
for alg in (
150+
PardisoJL(),
151+
MKLPardisoFactorize(),
152+
MKLPardisoIterate(),
153+
)
154+
155+
u = solve(prob1, alg; cache_kwargs...)
156+
@test A1 * u b1
157+
158+
# common interface doesn't support complex types
159+
# https://github.com/SciML/LinearSolve.jl/issues/38
160+
161+
# u = solve(prob2, alg; cache_kwargs...)
162+
# @test A2 * u ≈ b2
163+
end
164+
165+
end
166+
129167
@testset "Preconditioners" begin
130168
@testset "scaling_preconditioner" begin
131169
s = rand()

0 commit comments

Comments
 (0)