Skip to content

Commit 927744d

Browse files
authored
Fix Hessian calculation for NonlinearVariationalSolver block (#4641)
* NLVS tape block stores adj_sol per block, not per solver
* Allow pyadjoint to clean up the NLVS cached adjoint state
1 parent c7da9c3 commit 927744d

File tree

3 files changed

+40
-38
lines changed

3 files changed

+40
-38
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def _should_compute_boundary_adjoint(self, relevant_dependencies):
163163
def adj_sol(self):
164164
return self.adj_state
165165

166+
@adj_sol.setter
167+
def adj_sol(self, value):
168+
if self.adj_state is None:
169+
self.adj_state = value.copy(deepcopy=True)
170+
else:
171+
self.adj_state.assign(value)
172+
166173
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
167174
fwd_block_variable = self.get_outputs()[0]
168175
u = fwd_block_variable.output
@@ -187,7 +194,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
187194
adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq(
188195
dFdu_form, dJdu, compute_bdy
189196
)
190-
self.adj_state = adj_sol
197+
self.adj_sol = adj_sol
191198
if self.adj_cb is not None:
192199
self.adj_cb(adj_sol)
193200
if self.adj_bdy_cb is not None and compute_bdy:
@@ -408,7 +415,7 @@ def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
408415
firedrake.derivative(dFdu_form, fwd_block_variable.saved_output,
409416
tlm_output))
410417

411-
adj_sol = self.adj_state
418+
adj_sol = self.adj_sol
412419
if adj_sol is None:
413420
raise RuntimeError("Hessian computation was run before adjoint.")
414421
bdy = self._should_compute_boundary_adjoint(relevant_dependencies)
@@ -726,7 +733,7 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
726733
relevant_dependencies
727734
)
728735
adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy)
729-
self.adj_state = adj_sol
736+
self.adj_sol = adj_sol
730737
if self.adj_cb is not None:
731738
self.adj_cb(adj_sol)
732739
if self.adj_bdy_cb is not None and compute_bdy:

tests/firedrake/adjoint/test_hessian.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_simple_solve(rg):
3737
mesh = IntervalMesh(10, 0, 1)
3838
V = FunctionSpace(mesh, "Lagrange", 1)
3939

40-
f = Function(V).assign(2)
40+
f = Function(V).assign(2.)
4141

4242
u = TrialFunction(V)
4343
v = TestFunction(V)
@@ -76,10 +76,10 @@ def test_mixed_derivatives(rg):
7676
mesh = IntervalMesh(10, 0, 1)
7777
V = FunctionSpace(mesh, "Lagrange", 1)
7878

79-
f = Function(V).assign(2)
79+
f = Function(V).assign(2.)
8080
control_f = Control(f)
8181

82-
g = Function(V).assign(3)
82+
g = Function(V).assign(3.)
8383
control_g = Control(g)
8484

8585
u = TrialFunction(V)
@@ -126,7 +126,7 @@ def test_function(rg):
126126
R = FunctionSpace(mesh, "R", 0)
127127
c = Function(R, val=4)
128128
control_c = Control(c)
129-
f = Function(V).assign(3)
129+
f = Function(V).assign(3.)
130130
control_f = Control(f)
131131

132132
u = Function(V)
@@ -139,14 +139,14 @@ def test_function(rg):
139139
J = assemble(c ** 2 * u ** 2 * dx)
140140

141141
Jhat = ReducedFunctional(J, [control_c, control_f])
142-
dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True)
142+
dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True)
143143

144144
# Step direction for derivatives and convergence test
145145
h_c = Function(R, val=1.0)
146146
h_f = rg.uniform(V, 0, 10)
147147

148148
# Total derivative
149-
dJdc, dJdf = compute_gradient(J, [control_c, control_f], apply_riesz=True)
149+
dJdc, dJdf = compute_derivative(J, [control_c, control_f], apply_riesz=True)
150150
dJdm = assemble(dJdc * h_c * dx + dJdf * h_f * dx)
151151

152152
# Hessian
@@ -163,7 +163,7 @@ def test_nonlinear(rg):
163163
mesh = UnitSquareMesh(10, 10)
164164
V = FunctionSpace(mesh, "Lagrange", 1)
165165
R = FunctionSpace(mesh, "R", 0)
166-
f = Function(V).assign(5)
166+
f = Function(V).assign(5.)
167167

168168
u = Function(V)
169169
v = TestFunction(V)
@@ -201,11 +201,11 @@ def test_dirichlet(rg):
201201
mesh = UnitSquareMesh(10, 10)
202202
V = FunctionSpace(mesh, "Lagrange", 1)
203203

204-
f = Function(V).assign(30)
204+
f = Function(V).assign(30.)
205205

206206
u = Function(V)
207207
v = TestFunction(V)
208-
c = Function(V).assign(1)
208+
c = Function(V).assign(1.)
209209
bc = DirichletBC(V, c, "on_boundary")
210210

211211
F = inner(grad(u), grad(v)) * dx + u**4*v*dx - f**2 * v * dx
@@ -249,24 +249,25 @@ def Dt(u, u_, timestep):
249249
pr = project(sin(2*pi*x), V, annotate=False)
250250
ic = Function(V).assign(pr)
251251

252-
u_ = Function(V)
253-
u = Function(V)
252+
u_ = Function(V).assign(ic)
253+
u = Function(V).assign(ic)
254254
v = TestFunction(V)
255255

256256
nu = Constant(0.0001)
257257

258-
timestep = Constant(1.0/n)
258+
dt = 0.01
259+
nt = 20
259260

260261
params = {
261262
'snes_rtol': 1e-10,
262263
'ksp_type': 'preonly',
263264
'pc_type': 'lu',
264265
}
265266

266-
F = (Dt(u, ic, timestep)*v
267+
F = (Dt(u, u_, dt)*v
267268
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
269+
268270
bc = DirichletBC(V, 0.0, "on_boundary")
269-
t = 0.0
270271

271272
if solve_type == "nlvs":
272273
use_nlvs = True
@@ -285,21 +286,14 @@ def Dt(u, u_, timestep):
285286
else:
286287
solve(F == 0, u, bc, solver_parameters=params)
287288
u_.assign(u)
288-
t += float(timestep)
289289

290-
F = (Dt(u, u_, timestep)*v
291-
+ u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
292-
293-
end = 0.2
294-
while (t <= end):
290+
for _ in range(nt):
295291
if use_nlvs:
296292
solver.solve()
297293
else:
298294
solve(F == 0, u, bc, solver_parameters=params)
299295
u_.assign(u)
300296

301-
t += float(timestep)
302-
303297
J = assemble(u_*u_*dx + ic*ic*dx)
304298

305299
Jhat = ReducedFunctional(J, Control(ic))

tests/firedrake/regression/test_adjoint_operators.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_interpolate_vector_valued():
124124
J = assemble(inner(f, g)*u**2*dx)
125125
rf = ReducedFunctional(J, Control(f))
126126

127-
h = Function(V1).assign(1)
127+
h = Function(V1).assign(1.)
128128
assert taylor_test(rf, f, h) > 1.9
129129

130130

@@ -144,7 +144,7 @@ def test_interpolate_tlm():
144144
J = assemble(inner(f, g)*u**2*dx)
145145
rf = ReducedFunctional(J, Control(f))
146146

147-
h = Function(V1).assign(1)
147+
h = Function(V1).assign(1.)
148148
f.block_variable.tlm_value = h
149149

150150
tape = get_working_tape()
@@ -259,7 +259,7 @@ def test_interpolate_to_function_space_cross_mesh():
259259
mesh_src = UnitSquareMesh(2, 2)
260260
mesh_dest = UnitSquareMesh(3, 3, quadrilateral=True)
261261
V = FunctionSpace(mesh_src, "CG", 1)
262-
W = FunctionSpace(mesh_dest, "DG", 1)
262+
W = FunctionSpace(mesh_dest, "DQ", 1)
263263
R = FunctionSpace(mesh_src, "R", 0)
264264
u = Function(V)
265265

@@ -290,7 +290,7 @@ def test_interpolate_hessian_linear_expr(rg):
290290
# space h and perturbation direction g.
291291
W = FunctionSpace(mesh, "Lagrange", 2)
292292
R = FunctionSpace(mesh, "R", 0)
293-
f = Function(W).assign(5)
293+
f = Function(W).assign(5.)
294294
# Note that we interpolate from a linear expression
295295
expr_interped = Function(V).interpolate(2*f)
296296

@@ -345,7 +345,7 @@ def test_interpolate_hessian_nonlinear_expr(rg):
345345
# space h and perterbation direction g.
346346
W = FunctionSpace(mesh, "Lagrange", 2)
347347
R = FunctionSpace(mesh, "R", 0)
348-
f = Function(W).assign(5)
348+
f = Function(W).assign(5.)
349349
# Note that we interpolate from a nonlinear expression
350350
expr_interped = Function(V).interpolate(f**2)
351351

@@ -400,8 +400,8 @@ def test_interpolate_hessian_nonlinear_expr_multi(rg):
400400
# space h and perturbation direction g.
401401
W = FunctionSpace(mesh, "Lagrange", 2)
402402
R = FunctionSpace(mesh, "R", 0)
403-
f = Function(W).assign(5)
404-
w = Function(W).assign(4)
403+
f = Function(W).assign(5.)
404+
w = Function(W).assign(4.)
405405
c = Function(R, val=2.0)
406406
# Note that we interpolate from a nonlinear expression with 3 coefficients
407407
expr_interped = Function(V).interpolate(f**2+w**2+c**2)
@@ -460,8 +460,8 @@ def test_interpolate_hessian_nonlinear_expr_multi_cross_mesh(rg):
460460
mesh_src = UnitSquareMesh(11, 11)
461461
R_src = FunctionSpace(mesh_src, "R", 0)
462462
W = FunctionSpace(mesh_src, "Lagrange", 2)
463-
f = Function(W).assign(5)
464-
w = Function(W).assign(4)
463+
f = Function(W).assign(5.)
464+
w = Function(W).assign(4.)
465465
c = Function(R_src, val=2.0)
466466
# Note that we interpolate from a nonlinear expression with 3 coefficients
467467
expr_interped = Function(V).interpolate(f**2+w**2+c**2)
@@ -1035,8 +1035,9 @@ def u_analytical(x, a, b):
10351035
tape = get_working_tape()
10361036
# Check the checkpointed boundary conditions are not updating the
10371037
# user-defined boundary conditions ``bc_left`` and ``bc_right``.
1038-
assert isinstance(tape._blocks[0], DirichletBCBlock) and \
1039-
tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg
1038+
assert isinstance(tape._blocks[0], DirichletBCBlock)
1039+
assert tape._blocks[0]._outputs[0].checkpoint.checkpoint is not bc_left._original_arg
1040+
10401041
# tape._blocks[1] is the DirichletBC block for the right boundary
1041-
assert isinstance(tape._blocks[1], DirichletBCBlock) and \
1042-
tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg
1042+
assert isinstance(tape._blocks[1], DirichletBCBlock)
1043+
assert tape._blocks[1]._outputs[0].checkpoint.checkpoint is not bc_right._original_arg

0 commit comments

Comments
 (0)