Skip to content

Commit 28c0189

Browse files
authored
feat: add PETScSNES (#482)
* feat: add `PETScSNES` * feat: support automatic sparsity detection for PETSc * test: add PETScSNES to the wrapper tests * docs: add PETSc example * test: skip PETSc tests on windows * docs: print the benchmark results
1 parent 21b02bd commit 28c0189

File tree

12 files changed

+384
-40
lines changed

12 files changed

+384
-40
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
4141
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
4242
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
4343
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
44+
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4445
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
4546
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
47+
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
4648
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4749
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4850
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
@@ -55,6 +57,7 @@ NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
5557
NonlinearSolveMINPACKExt = "MINPACK"
5658
NonlinearSolveNLSolversExt = "NLSolvers"
5759
NonlinearSolveNLsolveExt = ["NLsolve", "LineSearches"]
60+
NonlinearSolvePETScExt = ["PETSc", "MPI"]
5861
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5962
NonlinearSolveSpeedMappingExt = "SpeedMapping"
6063
NonlinearSolveSundialsExt = "Sundials"
@@ -86,13 +89,15 @@ LineSearches = "7.3"
8689
LinearAlgebra = "1.10"
8790
LinearSolve = "2.35"
8891
MINPACK = "1.2"
92+
MPI = "0.20.22"
8993
MaybeInplace = "0.1.4"
9094
NLSolvers = "0.5"
9195
NLsolve = "4.5"
9296
NaNMath = "1"
9397
NonlinearProblemLibrary = "0.1.2"
9498
NonlinearSolveBase = "1"
9599
OrdinaryDiffEqTsit5 = "1.1.0"
100+
PETSc = "0.2"
96101
Pkg = "1.10"
97102
PrecompileTools = "1.2"
98103
Preferences = "1.4"
@@ -139,6 +144,7 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
139144
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
140145
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
141146
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
147+
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
142148
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
143149
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
144150
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -152,4 +158,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
152158
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
153159

154160
[targets]
155-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
161+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1515
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1616
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1717
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
18+
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
1819
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2021
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -31,6 +32,7 @@ AlgebraicMultigrid = "0.5, 0.6"
3132
ArrayInterface = "6, 7"
3233
BenchmarkTools = "1"
3334
BracketingNonlinearSolve = "1"
35+
DiffEqBase = "6.158"
3436
DifferentiationInterface = "0.6.16"
3537
Documenter = "1"
3638
DocumenterCitations = "1"
@@ -41,6 +43,7 @@ LinearSolve = "2"
4143
NonlinearSolve = "4"
4244
NonlinearSolveBase = "1"
4345
OrdinaryDiffEqTsit5 = "1.1.0"
46+
PETSc = "0.2"
4447
Plots = "1"
4548
Random = "1.10"
4649
SciMLBase = "2.4"

docs/pages.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ pages = [
99
"tutorials/modelingtoolkit.md",
1010
"tutorials/small_compile.md",
1111
"tutorials/iterator_interface.md",
12-
"tutorials/optimizing_parameterized_ode.md"
12+
"tutorials/optimizing_parameterized_ode.md",
13+
"tutorials/snes_ex2.md"
1314
],
1415
"Basics" => Any[
1516
"basics/nonlinear_problem.md",
@@ -45,6 +46,7 @@ pages = [
4546
"api/minpack.md",
4647
"api/nlsolve.md",
4748
"api/nlsolvers.md",
49+
"api/petsc.md",
4850
"api/siamfanlequations.md",
4951
"api/speedmapping.md",
5052
"api/sundials.md"

docs/src/api/petsc.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# PETSc.jl
2+
3+
This is a extension for importing solvers from PETSc.jl SNES into the SciML interface. Note
4+
that these solvers do not come by default, and thus one needs to install the package before
5+
using these solvers:
6+
7+
```julia
8+
using Pkg
9+
Pkg.add("PETSc")
10+
using PETSc, NonlinearSolve
11+
```
12+
13+
## Solver API
14+
15+
```@docs
16+
PETScSNES
17+
```

docs/src/solvers/nonlinear_system_solvers.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,12 @@ This is a wrapper package for importing solvers from NLSolvers.jl into the SciML
177177
[NLSolvers.jl](https://github.com/JuliaNLSolvers/NLSolvers.jl)
178178

179179
For a list of possible solvers see the [NLSolvers.jl documentation](https://julianlsolvers.github.io/NLSolvers.jl/)
180+
181+
### PETSc.jl
182+
183+
This is a wrapper package for importing solvers from PETSc.jl into the SciML interface.
184+
185+
- [`PETScSNES()`](@ref): A wrapper for
186+
[PETSc.jl](https://github.com/JuliaParallel/PETSc.jl)
187+
188+
For a list of possible solvers see the [PETSc.jl documentation](https://petsc.org/release/manual/snes/)

docs/src/tutorials/snes_ex2.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# [PETSc SNES Example 2](@id snes_ex2)
2+
3+
This implements `src/snes/examples/tutorials/ex2.c` from PETSc and `examples/SNES_ex2.jl`
4+
from PETSc.jl using automatic sparsity detection and automatic differentiation using
5+
`NonlinearSolve.jl`.
6+
7+
This solves the equations sequentially. Newton method to solve
8+
`u'' + u^{2} = f`, sequentially.
9+
10+
```@example snes_ex2
11+
using NonlinearSolve, PETSc, LinearAlgebra, SparseConnectivityTracer, BenchmarkTools
12+
13+
u0 = fill(0.5, 128)
14+
15+
function form_residual!(resid, x, _)
16+
n = length(x)
17+
xp = LinRange(0.0, 1.0, n)
18+
F = 6xp .+ (xp .+ 1e-12) .^ 6
19+
20+
dx = 1 / (n - 1)
21+
resid[1] = x[1]
22+
for i in 2:(n - 1)
23+
resid[i] = (x[i - 1] - 2x[i] + x[i + 1]) / dx^2 + x[i] * x[i] - F[i]
24+
end
25+
resid[n] = x[n] - 1
26+
27+
return
28+
end
29+
```
30+
31+
To use automatic sparsity detection, we need to specify `sparsity` keyword argument to
32+
`NonlinearFunction`. See [Automatic Sparsity Detection](@ref sparsity-detection) for more
33+
details.
34+
35+
```@example snes_ex2
36+
nlfunc_dense = NonlinearFunction(form_residual!)
37+
nlfunc_sparse = NonlinearFunction(form_residual!; sparsity = TracerSparsityDetector())
38+
39+
nlprob_dense = NonlinearProblem(nlfunc_dense, u0)
40+
nlprob_sparse = NonlinearProblem(nlfunc_sparse, u0)
41+
```
42+
43+
Now we can solve the problem using `PETScSNES` or with one of the native `NonlinearSolve.jl`
44+
solvers.
45+
46+
```@example snes_ex2
47+
sol_dense_nr = solve(nlprob_dense, NewtonRaphson(); abstol = 1e-8)
48+
sol_dense_snes = solve(nlprob_dense, PETScSNES(); abstol = 1e-8)
49+
sol_dense_nr .- sol_dense_snes
50+
```
51+
52+
```@example snes_ex2
53+
sol_sparse_nr = solve(nlprob_sparse, NewtonRaphson(); abstol = 1e-8)
54+
sol_sparse_snes = solve(nlprob_sparse, PETScSNES(); abstol = 1e-8)
55+
sol_sparse_nr .- sol_sparse_snes
56+
```
57+
58+
As expected the solutions are the same (upto floating point error). Now let's compare the
59+
runtimes.
60+
61+
## Runtimes
62+
63+
### Dense Jacobian
64+
65+
```@example snes_ex2
66+
@benchmark solve($(nlprob_dense), $(NewtonRaphson()); abstol = 1e-8)
67+
```
68+
69+
```@example snes_ex2
70+
@benchmark solve($(nlprob_dense), $(PETScSNES()); abstol = 1e-8)
71+
```
72+
73+
### Sparse Jacobian
74+
75+
```@example snes_ex2
76+
@benchmark solve($(nlprob_sparse), $(NewtonRaphson()); abstol = 1e-8)
77+
```
78+
79+
```@example snes_ex2
80+
@benchmark solve($(nlprob_sparse), $(PETScSNES()); abstol = 1e-8)
81+
```

ext/NonlinearSolvePETScExt.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
module NonlinearSolvePETScExt
2+
3+
using FastClosures: @closure
4+
using MPI: MPI
5+
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
6+
using NonlinearSolve: NonlinearSolve, PETScSNES
7+
using PETSc: PETSc
8+
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
9+
using SparseArrays: AbstractSparseMatrix
10+
11+
function SciMLBase.__solve(
12+
prob::NonlinearProblem, alg::PETScSNES, args...; abstol = nothing, reltol = nothing,
13+
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
14+
show_trace::Val{ShT} = Val(false), kwargs...) where {ShT}
15+
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
16+
termination_condition === nothing ||
17+
error("`PETScSNES` does not support termination conditions!")
18+
19+
_f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
20+
T = eltype(prob.u0)
21+
@assert T PETSc.scalar_types
22+
23+
if alg.petsclib === missing
24+
petsclibidx = findfirst(PETSc.petsclibs) do petsclib
25+
petsclib isa PETSc.PetscLibType{T}
26+
end
27+
28+
if petsclibidx === nothing
29+
error("No compatible PETSc library found for element type $(T). Pass in a \
30+
custom `petsclib` via `PETScSNES(; petsclib = <petsclib>, ....)`.")
31+
end
32+
petsclib = PETSc.petsclibs[petsclibidx]
33+
else
34+
petsclib = alg.petsclib
35+
end
36+
PETSc.initialized(petsclib) || PETSc.initialize(petsclib)
37+
38+
abstol = get_tolerance(abstol, T)
39+
reltol = get_tolerance(reltol, T)
40+
41+
nf = Ref{Int}(0)
42+
43+
f! = @closure (cfx, cx, user_ctx) -> begin
44+
nf[] += 1
45+
fx = cfx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cfx; read = false) : cfx
46+
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
47+
_f!(fx, x)
48+
Base.finalize(fx)
49+
Base.finalize(x)
50+
return
51+
end
52+
53+
snes = PETSc.SNES{T}(petsclib,
54+
alg.mpi_comm === missing ? MPI.COMM_SELF : alg.mpi_comm;
55+
alg.snes_options..., snes_monitor = ShT, snes_rtol = reltol,
56+
snes_atol = abstol, snes_max_it = maxiters)
57+
58+
PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))
59+
60+
if alg.autodiff === missing && prob.f.jac === nothing
61+
_jac! = nothing
62+
njac = Ref{Int}(-1)
63+
else
64+
autodiff = alg.autodiff === missing ? nothing : alg.autodiff
65+
if prob.u0 isa Number
66+
_jac! = NonlinearSolve.__construct_extension_jac(
67+
prob, alg, prob.u0, prob.u0; autodiff)
68+
J_init = zeros(T, 1, 1)
69+
else
70+
_jac!, J_init = NonlinearSolve.__construct_extension_jac(
71+
prob, alg, u0, resid; autodiff, initial_jacobian = Val(true))
72+
end
73+
74+
njac = Ref{Int}(0)
75+
76+
if J_init isa AbstractSparseMatrix
77+
PJ = PETSc.MatSeqAIJ(J_init)
78+
jac! = @closure (cx, J, _, user_ctx) -> begin
79+
njac[] += 1
80+
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
81+
if J isa PETSc.AbstractMat
82+
_jac!(user_ctx.jacobian, x)
83+
copyto!(J, user_ctx.jacobian)
84+
PETSc.assemble(J)
85+
else
86+
_jac!(J, x)
87+
end
88+
Base.finalize(x)
89+
return
90+
end
91+
PETSc.setjacobian!(snes, jac!, PJ, PJ)
92+
snes.user_ctx = (; jacobian = J_init)
93+
else
94+
PJ = PETSc.MatSeqDense(J_init)
95+
jac! = @closure (cx, J, _, user_ctx) -> begin
96+
njac[] += 1
97+
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
98+
_jac!(J, x)
99+
Base.finalize(x)
100+
J isa PETSc.AbstractMat && PETSc.assemble(J)
101+
return
102+
end
103+
PETSc.setjacobian!(snes, jac!, PJ, PJ)
104+
end
105+
end
106+
107+
res = PETSc.solve!(u0, snes)
108+
109+
_f!(resid, res)
110+
u_ = prob.u0 isa Number ? res[1] : res
111+
resid_ = prob.u0 isa Number ? resid[1] : resid
112+
113+
objective = maximum(abs, resid)
114+
# XXX: Return Code from PETSc
115+
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
116+
return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes,
117+
stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1))
118+
end
119+
120+
end

src/NonlinearSolve.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ include("algorithms/extension_algs.jl")
9999
include("utils.jl")
100100
include("default.jl")
101101

102+
const ALL_SOLVER_TYPES = [
103+
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
104+
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
105+
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
106+
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
107+
CMINPACK, PETScSNES,
108+
NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
109+
]
110+
102111
include("internal/forward_diff.jl") # we need to define after the algorithms
103112

104113
@setup_workload begin
@@ -171,8 +180,9 @@ export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPoly
171180
FastShortcutNLLSPolyalg
172181

173182
# Extension Algorithms
174-
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
183+
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
175184
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
185+
export PETScSNES, CMINPACK
176186

177187
# Advanced Algorithms -- Without Bells and Whistles
178188
export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, GeneralizedDFSane

0 commit comments

Comments
 (0)