Skip to content

Commit 8fb6f56

Browse files
authored
Test ReducedFunctional.tlm (#4448)
1 parent fa0a850 commit 8fb6f56

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

.github/workflows/core.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ jobs:
207207
--extra-index-url https://download.pytorch.org/whl/cpu \
208208
"$(echo ./firedrake-repo/dist/firedrake-*.tar.gz)[ci,docs]"
209209
: # TODO: Remove before merge
210-
pip install --verbose --editable git+https://github.com/dolfin-adjoint/pyadjoint.git@dham/abstract_reduced_functional#egg=pyadjoint-ad
210+
pip install --verbose --editable git+https://github.com/dolfin-adjoint/pyadjoint.git@JHopeCollins/tlm#egg=pyadjoint-ad
211211
212212
firedrake-clean
213213
pip list

firedrake/adjoint/ensemble_reduced_functional.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,38 @@ def derivative(self, adj_input=1.0, apply_riesz=False):
209209
return dJdm_local.delist(dJdm_total)
210210
return dJdm_local
211211

212-
def hessian(self, m_dot, apply_riesz=False):
212+
def tlm(self, m_dot):
213+
"""Return the action of the tangent linear model of the functional.
214+
215+
The tangent linear model is evaluated w.r.t. the control on a vector
216+
m_dot, around the last supplied value of the control.
217+
218+
Parameters
219+
----------
220+
m_dot : pyadjoint.OverloadedType
221+
The direction in which to compute the action of the tangent linear model.
222+
223+
Returns
224+
-------
225+
pyadjoint.OverloadedType: The action of the tangent linear model in the
226+
direction m_dot. Should be an instance of the same type as the functional.
227+
"""
228+
local_tlm = self.local_reduced_functional.tlm(m_dot)
229+
ensemble_comm = self.ensemble.ensemble_comm
230+
if self.gather_functional:
231+
mdot_g = self._allgather_J(local_tlm)
232+
total_tlm = self.gather_functional.tlm(mdot_g)
233+
# if gather_functional is None then we do a sum
234+
elif isinstance(local_tlm, float):
235+
total_tlm = ensemble_comm.allreduce(sendobj=local_tlm, op=MPI.SUM)
236+
elif isinstance(local_tlm, Function):
237+
total_tlm = type(local_tlm)(local_tlm.function_space())
238+
total_tlm = self.ensemble.allreduce(local_tlm, total_tlm)
239+
else:
240+
raise NotImplementedError("This type of functional is not supported.")
241+
return total_tlm
242+
243+
def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
213244
"""The Hessian is not yet implemented for ensemble reduced functional.
214245
215246
Raises:

tests/firedrake/adjoint/test_tlm.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def test_tlm_assemble():
5454
h = Function(V)
5555
h.vector()[:] = rand(h.dof_dset.size)
5656
g = f.copy(deepcopy=True)
57-
f.block_variable.tlm_value = h
58-
tape.evaluate_tlm()
59-
assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9)
57+
assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9)
6058

6159

6260
@pytest.mark.skipcomplex
@@ -80,12 +78,7 @@ def test_tlm_bc():
8078
J = assemble(c ** 2 * u * dx)
8179
Jhat = ReducedFunctional(J, Control(c))
8280

83-
# Need to specify the domain for the constant as `ufl.action`, which requires `ufl.Constant`
84-
# to have a function space, will be applied on the tlm value.
85-
c.block_variable.tlm_value = Function(R, val=1)
86-
tape.evaluate_tlm()
87-
88-
assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)
81+
assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=Jhat.tlm(Function(R, val=1))) > 1.9)
8982

9083

9184
@pytest.mark.skipcomplex
@@ -113,10 +106,8 @@ def test_tlm_func():
113106
h = Function(V)
114107
h.vector()[:] = rand(h.dof_dset.size)
115108
g = c.copy(deepcopy=True)
116-
c.block_variable.tlm_value = h
117-
tape.evaluate_tlm()
118109

119-
assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9)
110+
assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9)
120111

121112

122113
@pytest.mark.parametrize("solve_type",
@@ -170,9 +161,7 @@ def test_time_dependent(solve_type):
170161
Jhat = ReducedFunctional(J, control)
171162
h = Function(V)
172163
h.vector()[:] = rand(h.dof_dset.size)
173-
u_1.tlm_value = h
174-
tape.evaluate_tlm()
175-
assert (taylor_test(Jhat, control.tape_value(), h, dJdm=J.block_variable.tlm_value) > 1.9)
164+
assert (taylor_test(Jhat, control.tape_value(), h, dJdm=Jhat.tlm(h)) > 1.9)
176165

177166

178167
@pytest.mark.skipcomplex
@@ -224,9 +213,7 @@ def Dt(u, u_, timestep):
224213
h = Function(V)
225214
h.vector()[:] = rand(h.dof_dset.size)
226215
g = ic.copy(deepcopy=True)
227-
ic.block_variable.tlm_value = h
228-
tape.evaluate_tlm()
229-
assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9)
216+
assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9)
230217

231218

232219
@pytest.mark.skipcomplex
@@ -255,9 +242,7 @@ def test_projection():
255242
J = assemble(u_**2*dx)
256243
Jhat = ReducedFunctional(J, Control(k))
257244

258-
k.block_variable.tlm_value = Constant(1)
259-
tape.evaluate_tlm()
260-
assert (taylor_test(Jhat, k, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)
245+
assert (taylor_test(Jhat, k, Function(R, val=1), dJdm=Jhat.tlm(Constant(1))) > 1.9)
261246

262247

263248
@pytest.mark.skipcomplex
@@ -268,7 +253,6 @@ def test_projection_function():
268253
V = FunctionSpace(mesh, "CG", 1)
269254

270255
bc = DirichletBC(V, Constant(1), "on_boundary")
271-
# g = Function(V)
272256
x, y = SpatialCoordinate(mesh)
273257
g = project(sin(x)*sin(y), V, annotate=False)
274258
expr = sin(g*x)
@@ -289,6 +273,4 @@ def test_projection_function():
289273
h = Function(V)
290274
h.vector()[:] = rand(h.dof_dset.size)
291275
m = g.copy(deepcopy=True)
292-
g.block_variable.tlm_value = h
293-
tape.evaluate_tlm()
294-
assert (taylor_test(Jhat, m, h, dJdm=J.block_variable.tlm_value) > 1.9)
276+
assert (taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.9)

0 commit comments

Comments
 (0)