Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ jobs:
--extra-index-url https://download.pytorch.org/whl/cpu \
"$(echo ./firedrake-repo/dist/firedrake-*.tar.gz)[ci,docs]"
: # TODO: Remove before merge
pip install --verbose --editable git+https://github.com/dolfin-adjoint/pyadjoint.git@dham/abstract_reduced_functional#egg=pyadjoint-ad
pip install --verbose --editable git+https://github.com/dolfin-adjoint/pyadjoint.git@JHopeCollins/tlm#egg=pyadjoint-ad

firedrake-clean
pip list
Expand Down
16 changes: 16 additions & 0 deletions firedrake/adjoint/ensemble_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,22 @@ def derivative(self, adj_input=1.0, apply_riesz=False):
return dJdm_local.delist(dJdm_total)
return dJdm_local

def tlm(self, m_dot):
local_tlm = self.local_reduced_functional.tlm(m_dot)
ensemble_comm = self.ensemble.ensemble_comm
if self.gather_functional:
mdot_g = self._allgather_J(local_tlm)
total_tlm = self.gather_functional.tlm(mdot_g)
# if gather_functional is None then we do a sum
elif isinstance(local_tlm, float):
total_tlm = ensemble_comm.allreduce(sendobj=local_tlm, op=MPI.SUM)
elif isinstance(local_tlm, Function):
total_tlm = type(local_tlm)(local_tlm.function_space())
total_tlm = self.ensemble.allreduce(local_tlm, total_tlm)
else:
raise NotImplementedError("This type of functional is not supported.")
return total_tlm

def hessian(self, m_dot, apply_riesz=False):
"""The Hessian is not yet implemented for ensemble reduced functional.

Expand Down
32 changes: 7 additions & 25 deletions tests/firedrake/adjoint/test_tlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def test_tlm_assemble():
h = Function(V)
h.vector()[:] = rand(h.dof_dset.size)
g = f.copy(deepcopy=True)
f.block_variable.tlm_value = h
tape.evaluate_tlm()
assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9)


@pytest.mark.skipcomplex
Expand All @@ -80,12 +78,7 @@ def test_tlm_bc():
J = assemble(c ** 2 * u * dx)
Jhat = ReducedFunctional(J, Control(c))

# Need to specify the domain for the constant as `ufl.action`, which requires `ufl.Constant`
# to have a function space, will be applied on the tlm value.
c.block_variable.tlm_value = Function(R, val=1)
tape.evaluate_tlm()

assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=Jhat.tlm(Function(R, val=1))) > 1.9)


@pytest.mark.skipcomplex
Expand Down Expand Up @@ -113,10 +106,8 @@ def test_tlm_func():
h = Function(V)
h.vector()[:] = rand(h.dof_dset.size)
g = c.copy(deepcopy=True)
c.block_variable.tlm_value = h
tape.evaluate_tlm()

assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9)


@pytest.mark.parametrize("solve_type",
Expand Down Expand Up @@ -170,9 +161,7 @@ def test_time_dependent(solve_type):
Jhat = ReducedFunctional(J, control)
h = Function(V)
h.vector()[:] = rand(h.dof_dset.size)
u_1.tlm_value = h
tape.evaluate_tlm()
assert (taylor_test(Jhat, control.tape_value(), h, dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, control.tape_value(), h, dJdm=Jhat.tlm(h)) > 1.9)


@pytest.mark.skipcomplex
Expand Down Expand Up @@ -224,9 +213,7 @@ def Dt(u, u_, timestep):
h = Function(V)
h.vector()[:] = rand(h.dof_dset.size)
g = ic.copy(deepcopy=True)
ic.block_variable.tlm_value = h
tape.evaluate_tlm()
assert (taylor_test(Jhat, g, h, dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, g, h, dJdm=Jhat.tlm(h)) > 1.9)


@pytest.mark.skipcomplex
Expand Down Expand Up @@ -255,9 +242,7 @@ def test_projection():
J = assemble(u_**2*dx)
Jhat = ReducedFunctional(J, Control(k))

k.block_variable.tlm_value = Constant(1)
tape.evaluate_tlm()
assert (taylor_test(Jhat, k, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, k, Function(R, val=1), dJdm=Jhat.tlm(Constant(1))) > 1.9)


@pytest.mark.skipcomplex
Expand All @@ -268,7 +253,6 @@ def test_projection_function():
V = FunctionSpace(mesh, "CG", 1)

bc = DirichletBC(V, Constant(1), "on_boundary")
# g = Function(V)
x, y = SpatialCoordinate(mesh)
g = project(sin(x)*sin(y), V, annotate=False)
expr = sin(g*x)
Expand All @@ -289,6 +273,4 @@ def test_projection_function():
h = Function(V)
h.vector()[:] = rand(h.dof_dset.size)
m = g.copy(deepcopy=True)
g.block_variable.tlm_value = h
tape.evaluate_tlm()
assert (taylor_test(Jhat, m, h, dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, m, h, dJdm=Jhat.tlm(h)) > 1.9)
Loading