|
| 1 | +module NonlinearSolvePETScExt |
| 2 | + |
| 3 | +using FastClosures: @closure |
| 4 | +using MPI: MPI |
| 5 | +using NonlinearSolveBase: NonlinearSolveBase, get_tolerance |
| 6 | +using NonlinearSolve: NonlinearSolve, PETScSNES |
| 7 | +using PETSc: PETSc |
| 8 | +using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode |
| 9 | +using SparseArrays: AbstractSparseMatrix |
| 10 | + |
| 11 | +function SciMLBase.__solve( |
| 12 | + prob::NonlinearProblem, alg::PETScSNES, args...; abstol = nothing, reltol = nothing, |
| 13 | + maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing, |
| 14 | + show_trace::Val{ShT} = Val(false), kwargs...) where {ShT} |
| 15 | + # XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/ |
| 16 | + termination_condition === nothing || |
| 17 | + error("`PETScSNES` does not support termination conditions!") |
| 18 | + |
| 19 | + _f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0) |
| 20 | + T = eltype(prob.u0) |
| 21 | + @assert T ∈ PETSc.scalar_types |
| 22 | + |
| 23 | + if alg.petsclib === missing |
| 24 | + petsclibidx = findfirst(PETSc.petsclibs) do petsclib |
| 25 | + petsclib isa PETSc.PetscLibType{T} |
| 26 | + end |
| 27 | + |
| 28 | + if petsclibidx === nothing |
| 29 | + error("No compatible PETSc library found for element type $(T). Pass in a \ |
| 30 | + custom `petsclib` via `PETScSNES(; petsclib = <petsclib>, ....)`.") |
| 31 | + end |
| 32 | + petsclib = PETSc.petsclibs[petsclibidx] |
| 33 | + else |
| 34 | + petsclib = alg.petsclib |
| 35 | + end |
| 36 | + PETSc.initialized(petsclib) || PETSc.initialize(petsclib) |
| 37 | + |
| 38 | + abstol = get_tolerance(abstol, T) |
| 39 | + reltol = get_tolerance(reltol, T) |
| 40 | + |
| 41 | + nf = Ref{Int}(0) |
| 42 | + |
| 43 | + f! = @closure (cfx, cx, user_ctx) -> begin |
| 44 | + nf[] += 1 |
| 45 | + fx = cfx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cfx; read = false) : cfx |
| 46 | + x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx |
| 47 | + _f!(fx, x) |
| 48 | + Base.finalize(fx) |
| 49 | + Base.finalize(x) |
| 50 | + return |
| 51 | + end |
| 52 | + |
| 53 | + snes = PETSc.SNES{T}(petsclib, |
| 54 | + alg.mpi_comm === missing ? MPI.COMM_SELF : alg.mpi_comm; |
| 55 | + alg.snes_options..., snes_monitor = ShT, snes_rtol = reltol, |
| 56 | + snes_atol = abstol, snes_max_it = maxiters) |
| 57 | + |
| 58 | + PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0))) |
| 59 | + |
| 60 | + if alg.autodiff === missing && prob.f.jac === nothing |
| 61 | + _jac! = nothing |
| 62 | + njac = Ref{Int}(-1) |
| 63 | + else |
| 64 | + autodiff = alg.autodiff === missing ? nothing : alg.autodiff |
| 65 | + if prob.u0 isa Number |
| 66 | + _jac! = NonlinearSolve.__construct_extension_jac( |
| 67 | + prob, alg, prob.u0, prob.u0; autodiff) |
| 68 | + J_init = zeros(T, 1, 1) |
| 69 | + else |
| 70 | + _jac!, J_init = NonlinearSolve.__construct_extension_jac( |
| 71 | + prob, alg, u0, resid; autodiff, initial_jacobian = Val(true)) |
| 72 | + end |
| 73 | + |
| 74 | + njac = Ref{Int}(0) |
| 75 | + |
| 76 | + if J_init isa AbstractSparseMatrix |
| 77 | + PJ = PETSc.MatSeqAIJ(J_init) |
| 78 | + jac! = @closure (cx, J, _, user_ctx) -> begin |
| 79 | + njac[] += 1 |
| 80 | + x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx |
| 81 | + if J isa PETSc.AbstractMat |
| 82 | + _jac!(user_ctx.jacobian, x) |
| 83 | + copyto!(J, user_ctx.jacobian) |
| 84 | + PETSc.assemble(J) |
| 85 | + else |
| 86 | + _jac!(J, x) |
| 87 | + end |
| 88 | + Base.finalize(x) |
| 89 | + return |
| 90 | + end |
| 91 | + PETSc.setjacobian!(snes, jac!, PJ, PJ) |
| 92 | + snes.user_ctx = (; jacobian = J_init) |
| 93 | + else |
| 94 | + PJ = PETSc.MatSeqDense(J_init) |
| 95 | + jac! = @closure (cx, J, _, user_ctx) -> begin |
| 96 | + njac[] += 1 |
| 97 | + x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx |
| 98 | + _jac!(J, x) |
| 99 | + Base.finalize(x) |
| 100 | + J isa PETSc.AbstractMat && PETSc.assemble(J) |
| 101 | + return |
| 102 | + end |
| 103 | + PETSc.setjacobian!(snes, jac!, PJ, PJ) |
| 104 | + end |
| 105 | + end |
| 106 | + |
| 107 | + res = PETSc.solve!(u0, snes) |
| 108 | + |
| 109 | + _f!(resid, res) |
| 110 | + u_ = prob.u0 isa Number ? res[1] : res |
| 111 | + resid_ = prob.u0 isa Number ? resid[1] : resid |
| 112 | + |
| 113 | + objective = maximum(abs, resid) |
| 114 | + # XXX: Return Code from PETSc |
| 115 | + retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure) |
| 116 | + return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes, |
| 117 | + stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)) |
| 118 | +end |
| 119 | + |
| 120 | +end |
0 commit comments