11module NonlinearSolveGridapPETScExt
22
3- using Gridap: Gridap
3+ using Gridap: Gridap, Algebra
44using GridapPETSc: GridapPETSc
55
66using NonlinearSolveBase: NonlinearSolveBase
77using NonlinearSolve: NonlinearSolve, GridapPETScSNES
8+ using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
9+
10+ using ConcreteStructs: @concrete
11+ using FastClosures: @closure
12+
13+ @concrete struct NonlinearSolveOperator <: Algebra.NonlinearOperator
14+ f!
15+ jac!
16+ initial_guess_cache
17+ resid_prototype
18+ jacobian_prototype
19+ end
20+
21+ function Algebra. residual! (b:: AbstractVector , op:: NonlinearSolveOperator , x:: AbstractVector )
22+ op. f! (b, x)
23+ end
24+
25+ function Algebra. jacobian! (
26+ A:: AbstractMatrix , op:: NonlinearSolveOperator , x:: AbstractVector
27+ )
28+ op. jac! (A, x)
29+ end
30+
31+ function Algebra. zero_initial_guess (op:: NonlinearSolveOperator )
32+ fill! (op. initial_guess_cache, 0 )
33+ return op. initial_guess_cache
34+ end
35+
36+ function Algebra. allocate_residual (op:: NonlinearSolveOperator , :: AbstractVector )
37+ fill! (op. resid_prototype, 0 )
38+ return op. resid_prototype
39+ end
40+
41+ function Algebra. allocate_jacobian (op:: NonlinearSolveOperator , :: AbstractVector )
42+ fill! (op. jacobian_prototype, 0 )
43+ return op. jacobian_prototype
44+ end
45+
46+ # TODO : Later we should just wrap `Gridap` generally and pass in `PETSc` as the solver
47+ function SciMLBase. __solve (
48+ prob:: NonlinearProblem , alg:: GridapPETScSNES , args... ;
49+ abstol = nothing , reltol = nothing ,
50+ maxiters = 1000 , alias_u0:: Bool = false , termination_condition = nothing ,
51+ show_trace:: Val = Val (false ), kwargs...
52+ )
53+ # XXX : https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
54+ NonlinearSolveBase. assert_extension_supported_termination_condition (
55+ termination_condition, alg; abs_norm_supported = false
56+ )
57+
58+ f_wrapped!, u0, resid = NonlinearSolveBase. construct_extension_function_wrapper (
59+ prob; alias_u0
60+ )
61+ T = eltype (u0)
62+
63+ abstol = NonlinearSolveBase. get_tolerance (abstol, T)
64+ reltol = NonlinearSolveBase. get_tolerance (reltol, T)
65+
66+ nf = Ref {Int} (0 )
67+
68+ f! = @closure (fx, x) -> begin
69+ nf[] += 1
70+ f_wrapped! (fx, x)
71+ return fx
72+ end
73+
74+ if prob. u0 isa Number
75+ jac! = NonlinearSolveBase. construct_extension_jac (
76+ prob, alg, prob. u0, prob. u0; alg. autodiff
77+ )
78+ J_init = zeros (T, 1 , 1 )
79+ else
80+ jac!, J_init = NonlinearSolveBase. construct_extension_jac (
81+ prob, alg, u0, resid; alg. autodiff, initial_jacobian = Val (true )
82+ )
83+ end
84+
85+ njac = Ref {Int} (- 1 )
86+ jac_fn! = @closure (J, x) -> begin
87+ njac[] += 1
88+ jac! (J, x)
89+ return J
90+ end
91+
92+ nop = NonlinearSolveOperator (f!, jac_fn!, u0, resid, J_init)
93+
94+ petsc_args = [
95+ " -snes_rtol" , string (reltol), " -snes_atol" , string (abstol),
96+ " -snes_max_it" , string (maxiters)
97+ ]
98+ for (k, v) in pairs (alg. snes_options)
99+ push! (petsc_args, " -$(k) " )
100+ push! (petsc_args, string (v))
101+ end
102+ show_trace isa Val{true } && push! (petsc_args, " -snes_monitor" )
103+
104+ @show petsc_args
105+
106+ # TODO : We can reuse the cache returned from this function
107+ sol_u = GridapPETSc. with (args = petsc_args) do
108+ sol_u = copy (u0)
109+ Algebra. solve! (sol_u, GridapPETSc. PETScNonlinearSolver (), nop)
110+ return sol_u
111+ end
112+
113+ f_wrapped! (resid, sol_u)
114+ u_res = prob. u0 isa Number ? sol_u[1 ] : sol_u
115+ resid_res = prob. u0 isa Number ? resid[1 ] : resid
116+
117+ objective = maximum (abs, resid)
118+ retcode = ifelse (objective ≤ abstol, ReturnCode. Success, ReturnCode. Failure)
119+ return SciMLBase. build_solution (
120+ prob, alg, u_res, resid_res;
121+ retcode, stats = SciMLBase. NLStats (nf[], njac[], - 1 , - 1 , - 1 )
122+ )
123+ end
8124
9125end
0 commit comments