Skip to content

Commit 843fac9

Browse files
committed
one test passing.
1 parent 1cb3bc9 commit 843fac9

File tree

4 files changed

+45
-31
lines changed

4 files changed

+45
-31
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ julia = "1.6"
3333
DiffEqProblemLibrary = "a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9"
3434
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
3535
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36+
Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
3637

3738
[targets]
38-
test = ["Test", "OrdinaryDiffEq", "DiffEqProblemLibrary"]
39+
test = ["Test", "OrdinaryDiffEq", "DiffEqProblemLibrary", "Pardiso"]

src/common.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,no
5656

5757
function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorithm,Nothing}, args...;
5858
alias_A = false, alias_b = false,
59-
abstol=eps(eltype(prob.A)),
60-
reltol=eps(eltype(prob.A)),
61-
maxiters=length(prob.b),
59+
abstol=eps(eltype(prob.A)), # TODO handle when eps()
60+
reltol=eps(eltype(prob.A)), # not defined on eltype
61+
maxiters=length(prob.b), # like Int of Complex
6262
verbose=false,
6363
Pl = nothing,
6464
Pr = nothing,

src/pardiso.jl

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,35 @@ import Pardiso
55

66
export PardisoJL
77

8-
struct PardisoJL{A} <: SciMLLinearSolveAlgorithm
9-
nthreads::Union{Int, Nothing}
10-
solver_type::Union{Int, Pardiso.Solver, Nothing}
11-
matrix_type::Union{Int, Pardiso.MatrixType, Nothing}
12-
solve_phase::Union{Int, Pardiso.Phase, Nothing}
13-
release_phase::Union{Int, Nothing}
14-
iparm::Union{A, Nothing}
15-
dparm::Union{A, Nothing}
16-
end
17-
18-
function PardisoJL(solver_type=Pardiso.nothing,)
19-
20-
return PardisoJL(nthreads, solver_type, matrix_type, solve_phase,
21-
release_phase, iparm, dparm)
8+
Base.@kwdef struct PardisoJL <: SciMLLinearSolveAlgorithm
9+
nprocs::Union{Int, Nothing} = nothing
10+
solver_type::Union{Int, Pardiso.Solver, Nothing} = nothing
11+
matrix_type::Union{Int, Pardiso.MatrixType, Nothing} = nothing
12+
solve_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
13+
release_phase::Union{Int, Nothing} = nothing
14+
iparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
15+
dparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
2216
end
2317

2418
function init_cacheval(alg::PardisoJL, cache::LinearCache)
25-
@unpack nthreads, solver_type, matrix_type, iparm, dparm = alg
19+
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
2620

27-
solver = Pardiso.PARDISO_LOADED[] ? PardisoSolver() : MKLPardisoSolver()
21+
solver = Pardiso.PARDISO_LOADED[] ? Pardiso.PardisoSolver() : Pardiso.MKLPardisoSolver()
2822

2923
Pardiso.pardisoinit(solver) # default initialization
3024

31-
nthreads !== nothing && Pardiso.set_nprocs!(ps, nthreads)
25+
nprocs !== nothing && Pardiso.set_nprocs!(ps, nprocs)
3226
solver_type !== nothing && Pardiso.set_solver!(solver, key)
3327
matrix_type !== nothing && Pardiso.set_matrixtype!(solver, matrix_type)
3428
cache.verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
3529

36-
iparm !== nothing && begin # pass in vector of tuples like [(iparm, key)]
30+
if iparm !== nothing # pass in vector of tuples like [(iparm, key)]
3731
for i in length(iparm)
3832
Pardiso.set_iparm!(solver, iparm[i]...)
3933
end
4034
end
4135

42-
dparm !== nothing && begin
36+
if dparm !== nothing
4337
for i in length(dparm)
4438
Pardiso.set_dparm!(solver, dparm[i]...)
4539
end
@@ -49,7 +43,10 @@ function init_cacheval(alg::PardisoJL, cache::LinearCache)
4943
end
5044

5145
function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
52-
@unpack A, b, u, cacheval = cache
46+
@unpack A, b, u = cache
47+
if A isa DiffEqArrayOperator
48+
A = A.A
49+
end
5350

5451
if cache.isfresh
5552
solver = init_cacheval(alg, cache)
@@ -58,15 +55,15 @@ function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
5855

5956
abstol = cache.abstol
6057
reltol = cache.reltol
61-
kwargs = (abstol=abstol, reltol=reltol, alg.kwargs...)
58+
kwargs = (abstol=abstol, reltol=reltol)
6259

6360
"""
6461
figure out whatever phase is. should set_phase call be in init_cacheval?
6562
can we use phase to store factorization in cache?
6663
"""
67-
Pardiso.set_phase!(cacheval, alg.solve_phase)
68-
Pardiso.solve!(cacheval, u, A, b)
69-
Pardiso.set_phase!(cacheval, alg.release_phase) # is this necessary?
64+
alg.solve_phase !== nothing && Pardiso.set_phase!(cacheval, alg.solve_phase)
65+
Pardiso.solve!(cache.cacheval, u, A, b)
66+
alg.release_phase !== nothing && Pardiso.set_phase!(cacheval, alg.release_phase) # is this necessary?
7067

7168
return cache.u
7269
end

test/runtests.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,25 @@ end
126126
end
127127
end
128128

129+
@testset "PardisoJL" begin
130+
#@test_broken alg = PardisoJL()
131+
132+
using Pardiso, SparseArrays
133+
verbose = true
134+
135+
A = sparse([ 1. 0 -2 3
136+
0 5 1 2
137+
-2 1 4 -7
138+
3 2 -7 5 ])
139+
b = rand(4)
140+
141+
prob = LinearProblem(A, b)
142+
alg = PardisoJL()
143+
144+
u = solve(prob, alg; verbose=true)
145+
146+
end
147+
129148
@testset "Preconditioners" begin
130149
@testset "scaling_preconditioner" begin
131150
s = rand()
@@ -174,7 +193,4 @@ end
174193
end
175194
end
176195

177-
@testset "PardisoJL" begin
178-
end
179-
180196
end # testset

0 commit comments

Comments
 (0)