Skip to content

Commit fe0d790

Browse files
committed
Try to fix #497
Pardiso defaults for highly indefinite matrices. This commit essentially reverts #89 and introduces a new kwarg "cache_analysis" (default `false`) to PardisoJL() which, if true would lead to the behaviour of #89. Also, allow the user to overwrite all iparms modified by the extension besides of 12.
1 parent 270b56d commit fe0d790

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

ext/LinearSolvePardisoExt.jl

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
2222
reltol,
2323
verbose::Bool,
2424
assumptions::LinearSolve.OperatorAssumptions)
25-
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
25+
@unpack nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm = alg
2626
A = convert(AbstractMatrix, A)
2727

2828
solver = if Pardiso.PARDISO_LOADED[]
@@ -52,22 +52,6 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
5252
end
5353
verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
5454

55-
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
56-
if iparm !== nothing
57-
for i in iparm
58-
Pardiso.set_iparm!(solver, i...)
59-
end
60-
end
61-
62-
if dparm !== nothing
63-
for d in dparm
64-
Pardiso.set_dparm!(solver, d...)
65-
end
66-
end
67-
68-
# Make sure to say it's transposed because its CSC not CSR
69-
Pardiso.set_iparm!(solver, 12, 1)
70-
7155
#=
7256
Note: It is recommended to use IPARM(11)=1 (scaling) and IPARM(13)=1 (matchings) for
7357
highly indefinite symmetric matrices e.g. from interior point optimizations or saddle point problems.
@@ -79,10 +63,10 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
7963
be changed to Pardiso.ANALYSIS_NUM_FACT in the solver loop otherwise instabilities
8064
occur in the example https://github.com/SciML/OrdinaryDiffEq.jl/issues/1569
8165
=#
82-
Pardiso.set_iparm!(solver, 11, 0)
83-
Pardiso.set_iparm!(solver, 13, 0)
84-
85-
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
66+
if cache_analysis
67+
Pardiso.set_iparm!(solver, 11, 0)
68+
Pardiso.set_iparm!(solver, 13, 0)
69+
end
8670

8771
if alg.solver_type == 1
8872
# PARDISO uses a numerical factorization A = LU for the first system and
@@ -92,10 +76,30 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
9276
Pardiso.set_iparm!(solver, 3, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
9377
end
9478

95-
Pardiso.pardiso(solver,
96-
u,
97-
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
98-
b)
79+
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
80+
if iparm !== nothing
81+
for i in iparm
82+
Pardiso.set_iparm!(solver, i...)
83+
end
84+
end
85+
86+
if dparm !== nothing
87+
for d in dparm
88+
Pardiso.set_dparm!(solver, d...)
89+
end
90+
end
91+
92+
# Make sure to say it's transposed because its CSC not CSR
93+
# This is also the only value which should not be overwritten by users
94+
Pardiso.set_iparm!(solver, 12, 1)
95+
96+
if cache_analysis
97+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
98+
Pardiso.pardiso(solver,
99+
u,
100+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
101+
b)
102+
end
99103

100104
return solver
101105
end
@@ -105,7 +109,8 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs
105109
A = convert(AbstractMatrix, A)
106110

107111
if cache.isfresh
108-
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
112+
phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT
113+
Pardiso.set_phase!(cache.cacheval, phase)
109114
Pardiso.pardiso(cache.cacheval, A, eltype(A)[])
110115
cache.isfresh = false
111116
end

src/extension_algs.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,14 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
154154
nprocs::Union{Int, Nothing}
155155
solver_type::T1
156156
matrix_type::T2
157+
cache_analysis::Bool
157158
iparm::Union{Vector{Tuple{Int, Int}}, Nothing}
158159
dparm::Union{Vector{Tuple{Int, Int}}, Nothing}
159160

160161
function PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
161162
solver_type = nothing,
162163
matrix_type = nothing,
164+
cache_analysis = false,
163165
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
164166
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
165167
ext = Base.get_extension(@__MODULE__, :LinearSolvePardisoExt)
@@ -170,7 +172,8 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
170172
T2 = typeof(matrix_type)
171173
@assert T1 <: Union{Int, Nothing, ext.Pardiso.Solver}
172174
@assert T2 <: Union{Int, Nothing, ext.Pardiso.MatrixType}
173-
return new{T1, T2}(nprocs, solver_type, matrix_type, iparm, dparm)
175+
return new{T1, T2}(
176+
nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm)
174177
end
175178
end
176179
end

test/pardiso/pardiso.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearSolve, SparseArrays, Random
1+
using LinearSolve, SparseArrays, Random, LinearAlgebra
22
import Pardiso
33

44
A1 = sparse([1.0 0 -2 3
@@ -13,12 +13,22 @@ n = 4
1313
e = ones(n)
1414
e2 = ones(n - 1)
1515
A2 = spdiagm(-1 => im * e2, 0 => lambda * e, 1 => -im * e2)
16+
1617
b2 = rand(n) + im * zeros(n)
1718
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
1819

1920
prob2 = LinearProblem(A2, b2)
2021

21-
for alg in (PardisoJL(), MKLPardisoFactorize(), MKLPardisoIterate())
22+
for alg in (PardisoJL(), MKLPardisoFactorize())
23+
u = solve(prob1, alg; cache_kwargs...).u
24+
@test A1 * u b1
25+
26+
u = solve(prob2, alg; cache_kwargs...).u
27+
@test eltype(u) <: Complex
28+
@test A2 * u b2
29+
end
30+
31+
for alg in (MKLPardisoIterate(),)
2232
u = solve(prob1, alg; cache_kwargs...).u
2333
@test A1 * u b1
2434

@@ -27,6 +37,8 @@ for alg in (PardisoJL(), MKLPardisoFactorize(), MKLPardisoIterate())
2737
@test_broken A2 * u b2
2838
end
2939

40+
41+
3042
Random.seed!(10)
3143
A = sprand(n, n, 0.8);
3244
A2 = 2.0 .* A;
@@ -53,6 +65,25 @@ sol33 = solve(linsolve)
5365
@test sol12.u sol32.u
5466
@test sol13.u sol33.u
5567

68+
69+
# Test for problem from #497
70+
function makeA()
71+
n = 60
72+
colptr = [1, 4, 7, 11, 15, 17, 22, 26, 30, 34, 38, 40, 46, 50, 54, 58, 62, 64, 70, 74, 78, 82, 86, 88, 94, 98, 102, 106, 110, 112, 118, 122, 126, 130, 134, 136, 142, 146, 150, 154, 158, 160, 166, 170, 174, 178, 182, 184, 190, 194, 198, 202, 206, 208, 214, 218, 222, 224, 226, 228, 232]
73+
rowval = [1, 3, 4, 1, 2, 4, 2, 4, 9, 10, 3, 5, 11, 12, 1, 3, 2, 4, 6, 11, 12, 2, 7, 9, 10, 2, 7, 8, 10, 8, 10, 15, 16, 9, 11, 17, 18, 7, 9, 2, 8, 10, 12, 17, 18, 8, 13, 15, 16, 8, 13, 14, 16, 14, 16, 21, 22, 15, 17, 23, 24, 13, 15, 8, 14, 16, 18, 23, 24, 14, 19, 21, 22, 14, 19, 20, 22, 20, 22, 27, 28, 21, 23, 29, 30, 19, 21, 14, 20, 22, 24, 29, 30, 20, 25, 27, 28, 20, 25, 26, 28, 26, 28, 33, 34, 27, 29, 35, 36, 25, 27, 20, 26, 28, 30, 35, 36, 26, 31, 33, 34, 26, 31, 32, 34, 32, 34, 39, 40, 33, 35, 41, 42, 31, 33, 26, 32, 34, 36, 41, 42, 32, 37, 39, 40, 32, 37, 38, 40, 38, 40, 45, 46, 39, 41, 47, 48, 37, 39, 32, 38, 40, 42, 47, 48, 38, 43, 45, 46, 38, 43, 44, 46, 44, 46, 51, 52, 45, 47, 53, 54, 43, 45, 38, 44, 46, 48, 53, 54, 44, 49, 51, 52, 44, 49, 50, 52, 50, 52, 57, 58, 51, 53, 59, 60, 49, 51, 44, 50, 52, 54, 59, 60, 50, 55, 57, 58, 50, 55, 56, 58, 56, 58, 57, 59, 55, 57, 50, 56, 58, 60]
74+
nzval = [-0.64, 1.0, -1.0, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -1.0806825309567203, 1.0, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0]
75+
A = SparseMatrixCSC(n, n, colptr, rowval, nzval)
76+
return(A)
77+
end
78+
79+
A=makeA()
80+
u0=fill(0.1,size(A,2))
81+
linprob = LinearProblem(A, A*u0)
82+
u = LinearSolve.solve(linprob, PardisoJL(),verbose=true)
83+
@test norm(u-u0) < 1.0e-14
84+
85+
86+
5687
# Testing and demonstrating Pardiso.set_iparm! for MKLPardisoSolver
5788
solver = Pardiso.MKLPardisoSolver()
5889
iparm = [

0 commit comments

Comments
 (0)