Skip to content

Commit c2a3ab4

Browse files
committed
fix: working prototype of GridapPETSc
1 parent 3a211ad commit c2a3ab4

File tree

3 files changed

+120
-5
lines changed

3 files changed

+120
-5
lines changed

ext/NonlinearSolveGridapPETScExt.jl

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,125 @@
11
module NonlinearSolveGridapPETScExt
22

3-
using Gridap: Gridap
3+
using Gridap: Gridap, Algebra
44
using GridapPETSc: GridapPETSc
55

66
using NonlinearSolveBase: NonlinearSolveBase
77
using NonlinearSolve: NonlinearSolve, GridapPETScSNES
8+
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
9+
10+
using ConcreteStructs: @concrete
11+
using FastClosures: @closure
12+
13+
@concrete struct NonlinearSolveOperator <: Algebra.NonlinearOperator
14+
f!
15+
jac!
16+
initial_guess_cache
17+
resid_prototype
18+
jacobian_prototype
19+
end
20+
21+
function Algebra.residual!(b::AbstractVector, op::NonlinearSolveOperator, x::AbstractVector)
22+
op.f!(b, x)
23+
end
24+
25+
function Algebra.jacobian!(
26+
A::AbstractMatrix, op::NonlinearSolveOperator, x::AbstractVector
27+
)
28+
op.jac!(A, x)
29+
end
30+
31+
function Algebra.zero_initial_guess(op::NonlinearSolveOperator)
32+
fill!(op.initial_guess_cache, 0)
33+
return op.initial_guess_cache
34+
end
35+
36+
function Algebra.allocate_residual(op::NonlinearSolveOperator, ::AbstractVector)
37+
fill!(op.resid_prototype, 0)
38+
return op.resid_prototype
39+
end
40+
41+
function Algebra.allocate_jacobian(op::NonlinearSolveOperator, ::AbstractVector)
42+
fill!(op.jacobian_prototype, 0)
43+
return op.jacobian_prototype
44+
end
45+
46+
# TODO: Later we should just wrap `Gridap` generally and pass in `PETSc` as the solver
47+
function SciMLBase.__solve(
48+
prob::NonlinearProblem, alg::GridapPETScSNES, args...;
49+
abstol = nothing, reltol = nothing,
50+
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
51+
show_trace::Val = Val(false), kwargs...
52+
)
53+
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
54+
NonlinearSolveBase.assert_extension_supported_termination_condition(
55+
termination_condition, alg; abs_norm_supported = false
56+
)
57+
58+
f_wrapped!, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper(
59+
prob; alias_u0
60+
)
61+
T = eltype(u0)
62+
63+
abstol = NonlinearSolveBase.get_tolerance(abstol, T)
64+
reltol = NonlinearSolveBase.get_tolerance(reltol, T)
65+
66+
nf = Ref{Int}(0)
67+
68+
f! = @closure (fx, x) -> begin
69+
nf[] += 1
70+
f_wrapped!(fx, x)
71+
return fx
72+
end
73+
74+
if prob.u0 isa Number
75+
jac! = NonlinearSolveBase.construct_extension_jac(
76+
prob, alg, prob.u0, prob.u0; alg.autodiff
77+
)
78+
J_init = zeros(T, 1, 1)
79+
else
80+
jac!, J_init = NonlinearSolveBase.construct_extension_jac(
81+
prob, alg, u0, resid; alg.autodiff, initial_jacobian = Val(true)
82+
)
83+
end
84+
85+
njac = Ref{Int}(-1)
86+
jac_fn! = @closure (J, x) -> begin
87+
njac[] += 1
88+
jac!(J, x)
89+
return J
90+
end
91+
92+
nop = NonlinearSolveOperator(f!, jac_fn!, u0, resid, J_init)
93+
94+
petsc_args = [
95+
"-snes_rtol", string(reltol), "-snes_atol", string(abstol),
96+
"-snes_max_it", string(maxiters)
97+
]
98+
for (k, v) in pairs(alg.snes_options)
99+
push!(petsc_args, "-$(k)")
100+
push!(petsc_args, string(v))
101+
end
102+
show_trace isa Val{true} && push!(petsc_args, "-snes_monitor")
103+
104+
@show petsc_args
105+
106+
# TODO: We can reuse the cache returned from this function
107+
sol_u = GridapPETSc.with(args = petsc_args) do
108+
sol_u = copy(u0)
109+
Algebra.solve!(sol_u, GridapPETSc.PETScNonlinearSolver(), nop)
110+
return sol_u
111+
end
112+
113+
f_wrapped!(resid, sol_u)
114+
u_res = prob.u0 isa Number ? sol_u[1] : sol_u
115+
resid_res = prob.u0 isa Number ? resid[1] : resid
116+
117+
objective = maximum(abs, resid)
118+
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
119+
return SciMLBase.build_solution(
120+
prob, alg, u_res, resid_res;
121+
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
122+
)
123+
end
8124

9125
end

ext/NonlinearSolvePETScExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function SciMLBase.__solve(
1818
show_trace::Val = Val(false), kwargs...
1919
)
2020
if !MPI.Initialized()
21-
@warn "MPI not initialized. Initializing MPI with MPI.Init()." maxlog = 1
21+
@warn "MPI not initialized. Initializing MPI with MPI.Init()." maxlog=1
2222
MPI.Init()
2323
end
2424

@@ -132,8 +132,7 @@ function SciMLBase.__solve(
132132
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
133133
return SciMLBase.build_solution(
134134
prob, alg, u_res, resid_res;
135-
retcode, original = snes,
136-
stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
135+
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
137136
)
138137
end
139138

src/NonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,6 @@ export NonlinearSolvePolyAlgorithm, FastShortcutNonlinearPolyalg, FastShortcutNL
121121
# Extension Algorithms
122122
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
123123
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
124-
export PETScSNES, CMINPACK
124+
export PETScSNES, GridapPETScSNES, CMINPACK
125125

126126
end

0 commit comments

Comments
 (0)