11module  NonlinearSolveGridapPETScExt
22
3- using  Gridap:  Gridap
3+ using  Gridap:  Gridap, Algebra 
44using  GridapPETSc:  GridapPETSc
55
66using  NonlinearSolveBase:  NonlinearSolveBase
77using  NonlinearSolve:  NonlinearSolve, GridapPETScSNES
8+ using  SciMLBase:  SciMLBase, NonlinearProblem, ReturnCode
9+ 
10+ using  ConcreteStructs:  @concrete 
11+ using  FastClosures:  @closure 
12+ 
13+ @concrete  struct  NonlinearSolveOperator <:  Algebra.NonlinearOperator 
14+     f!
15+     jac!
16+     initial_guess_cache
17+     resid_prototype
18+     jacobian_prototype
19+ end 
20+ 
21+ function  Algebra. residual! (b:: AbstractVector , op:: NonlinearSolveOperator , x:: AbstractVector )
22+     op. f! (b, x)
23+ end 
24+ 
25+ function  Algebra. jacobian! (
26+         A:: AbstractMatrix , op:: NonlinearSolveOperator , x:: AbstractVector 
27+ )
28+     op. jac! (A, x)
29+ end 
30+ 
31+ function  Algebra. zero_initial_guess (op:: NonlinearSolveOperator )
32+     fill! (op. initial_guess_cache, 0 )
33+     return  op. initial_guess_cache
34+ end 
35+ 
36+ function  Algebra. allocate_residual (op:: NonlinearSolveOperator , :: AbstractVector )
37+     fill! (op. resid_prototype, 0 )
38+     return  op. resid_prototype
39+ end 
40+ 
41+ function  Algebra. allocate_jacobian (op:: NonlinearSolveOperator , :: AbstractVector )
42+     fill! (op. jacobian_prototype, 0 )
43+     return  op. jacobian_prototype
44+ end 
45+ 
46+ #  TODO : Later we should just wrap `Gridap` generally and pass in `PETSc` as the solver
47+ function  SciMLBase. __solve (
48+         prob:: NonlinearProblem , alg:: GridapPETScSNES , args... ;
49+         abstol =  nothing , reltol =  nothing ,
50+         maxiters =  1000 , alias_u0:: Bool  =  false , termination_condition =  nothing ,
51+         show_trace:: Val  =  Val (false ), kwargs... 
52+ )
53+     #  XXX : https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
54+     NonlinearSolveBase. assert_extension_supported_termination_condition (
55+         termination_condition, alg; abs_norm_supported =  false 
56+     )
57+ 
58+     f_wrapped!, u0, resid =  NonlinearSolveBase. construct_extension_function_wrapper (
59+         prob; alias_u0
60+     )
61+     T =  eltype (u0)
62+ 
63+     abstol =  NonlinearSolveBase. get_tolerance (abstol, T)
64+     reltol =  NonlinearSolveBase. get_tolerance (reltol, T)
65+ 
66+     nf =  Ref {Int} (0 )
67+ 
68+     f! =  @closure  (fx, x) ->  begin 
69+         nf[] +=  1 
70+         f_wrapped! (fx, x)
71+         return  fx
72+     end 
73+ 
74+     if  prob. u0 isa  Number
75+         jac! =  NonlinearSolveBase. construct_extension_jac (
76+             prob, alg, prob. u0, prob. u0; alg. autodiff
77+         )
78+         J_init =  zeros (T, 1 , 1 )
79+     else 
80+         jac!, J_init =  NonlinearSolveBase. construct_extension_jac (
81+             prob, alg, u0, resid; alg. autodiff, initial_jacobian =  Val (true )
82+         )
83+     end 
84+ 
85+     njac =  Ref {Int} (- 1 )
86+     jac_fn! =  @closure  (J, x) ->  begin 
87+         njac[] +=  1 
88+         jac! (J, x)
89+         return  J
90+     end 
91+ 
92+     nop =  NonlinearSolveOperator (f!, jac_fn!, u0, resid, J_init)
93+ 
94+     petsc_args =  [
95+         " -snes_rtol" string (reltol), " -snes_atol" string (abstol),
96+         " -snes_max_it" string (maxiters)
97+     ]
98+     for  (k, v) in  pairs (alg. snes_options)
99+         push! (petsc_args, " -$(k) " 
100+         push! (petsc_args, string (v))
101+     end 
102+     show_trace isa  Val{true } &&  push! (petsc_args, " -snes_monitor" 
103+ 
104+     @show  petsc_args
105+ 
106+     #  TODO : We can reuse the cache returned from this function
107+     sol_u =  GridapPETSc. with (args =  petsc_args) do 
108+         sol_u =  copy (u0)
109+         Algebra. solve! (sol_u, GridapPETSc. PETScNonlinearSolver (), nop)
110+         return  sol_u
111+     end 
112+ 
113+     f_wrapped! (resid, sol_u)
114+     u_res =  prob. u0 isa  Number ?  sol_u[1 ] :  sol_u
115+     resid_res =  prob. u0 isa  Number ?  resid[1 ] :  resid
116+ 
117+     objective =  maximum (abs, resid)
118+     retcode =  ifelse (objective ≤  abstol, ReturnCode. Success, ReturnCode. Failure)
119+     return  SciMLBase. build_solution (
120+         prob, alg, u_res, resid_res;
121+         retcode, stats =  SciMLBase. NLStats (nf[], njac[], - 1 , - 1 , - 1 )
122+     )
123+ end 
8124
9125end 
0 commit comments