Commit a7f80d4

dham, Ig-dolci, JHopeCollins, and pbrubeck authored and committed

Dham/abstract reduced functional (firedrakeproject#3941)

Co-authored-by: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com>
Co-authored-by: Josh Hope-Collins <joshua.hope-collins13@imperial.ac.uk>
Co-authored-by: Pablo Brubeck <brubeck@protonmail.com>

1 parent 247552e commit a7f80d4

28 files changed: +794 −314 lines changed
demos/full_waveform_inversion/full_waveform_inversion.py.rst

Lines changed: 16 additions & 9 deletions

@@ -109,10 +109,12 @@ built over the ``my_ensemble.comm`` (spatial) communicator.
     dt = 0.03  # time step in seconds
     final_time = 0.6  # final time in seconds
     nx, ny = 15, 15
+    ftol = 0.9  # optimisation tolerance
 else:
     dt = 0.002  # time step in seconds
     final_time = 1.0  # final time in seconds
     nx, ny = 80, 80
+    ftol = 1e-2  # optimisation tolerance

 mesh = UnitSquareMesh(nx, ny, comm=my_ensemble.comm)
@@ -280,21 +282,28 @@ To have the step 4, we need first to tape the forward problem. That is done by c

 We now instantiate :class:`~.EnsembleReducedFunctional`::

-    J_hat = EnsembleReducedFunctional(J_val, Control(c_guess), my_ensemble)
+    J_hat = EnsembleReducedFunctional(J_val,
+                                      Control(c_guess, riesz_map="l2"),
+                                      my_ensemble)

 which enables us to recompute :math:`J` and its gradient :math:`\nabla_{\mathtt{c\_guess}} J`,
 where the :math:`J_s` and its gradients :math:`\nabla_{\mathtt{c\_guess}} J_s` are computed in parallel
 based on the ``my_ensemble`` configuration.


 **Steps 4-6**: The instance of the :class:`~.EnsembleReducedFunctional`, named ``J_hat``,
-is then passed as an argument to the ``minimize`` function::
+is then passed as an argument to the ``minimize`` function. The default ``minimize`` function
+uses ``scipy.minimize``, and wraps the ``ReducedFunctional`` in a ``ReducedFunctionalNumPy``
+that handles transferring data between Firedrake and numpy data structures. However, because
+we have a custom ``ReducedFunctional``, we need to do this ourselves::

-    c_optimised = minimize(J_hat, method="L-BFGS-B", options={"disp": True, "maxiter": 1},
-                           bounds=(1.5, 2.0), derivative_options={"riesz_representation": 'l2'}
-                           )
+    from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy
+    Jnumpy = ReducedFunctionalNumPy(J_hat)

-The ``minimize`` function executes the optimisation algorithm until the stopping criterion (``maxiter``) is met.
+    c_optimised = minimize(Jnumpy, method="L-BFGS-B", options={"disp": True, "ftol": ftol},
+                           bounds=(1.5, 2.0))
+
+The ``minimize`` function executes the optimisation algorithm until the stopping criterion (``ftol``) is met.
 For 20 iterations, the predicted velocity model is shown in the following figure.

 .. image:: c_predicted.png
@@ -305,9 +314,7 @@ For 20 iterations, the predicted velocity model is shown in the following figure

 .. warning::

     The ``minimize`` function uses the SciPy library for optimisation. However, for scenarios that require higher
-    levels of spatial parallelism, you should assess whether SciPy is the most suitable option for your problem.
-    SciPy's optimisation algorithm is not inner-product-aware. Therefore, we configure the options with
-    ``derivative_options={"riesz_representation": 'l2'}`` to account for this requirement.
+    levels of spatial parallelism, you should assess whether SciPy is the most suitable option for your
+    problem; alternatives include pyadjoint's ``TAOSolver``.

 .. note::
Lines changed: 109 additions & 0 deletions (new file)

Shape optimization
==================

Shape optimization is about modifying the shape of a domain :math:`\Omega` so
that an objective function :math:`J(\Omega)` is minimized. In this demo, we
consider an objective function constrained to a boundary value problem and
implement a simple mesh-moving shape optimization strategy using Firedrake and
pyadjoint. This tutorial was contributed by `Alberto Paganini
<mailto:apaganini@le.ac.uk>`__ with support from `Ado Farsi
<mailto:ado.farsi@imperial.ac.uk>`__ and `Mirko Ciceri
<mailto:mc5823@ic.ac.uk>`__, and was written during the `ccp-dcm hackathon
<https://ccp-dcm.github.io/exeter_hackathon>`__ at Dartington Hall.

Let

.. math::

   J(\Omega) = \int_\Omega \big(u(\mathbf{x}) - u_t(\mathbf{x})\big)^2 \,\mathrm{d}\mathbf{x}

measure the difference between a steady-state temperature profile
:math:`u:\mathbb{R}^2\to\mathbb{R}` and a target steady-state temperature
profile :math:`u_t:\mathbb{R}^2\to\mathbb{R}`. Specifically, the function
:math:`u` is the solution to the steady-state heat equation

.. math::

   -\Delta u = 4 \quad \text{in }\Omega\,, \qquad u = 0 \quad \text{on } \partial\Omega

and the target temperature profile :math:`u_t` is

.. math::

   u_t(x,y) = 1.21 - (x - 0.5)^2 - (y - 0.5)^2\,.

Besides the empty set, the domain that minimizes :math:`J(\Omega)` is a disc of
radius :math:`1.1` centered at :math:`(0.5,0.5)`.

We can now proceed to set up the problem. We import firedrake and pyadjoint and
choose an initial guess (in this case, a unit disc centred at the origin)::

    from firedrake import *
    from firedrake.adjoint import *
    mesh = UnitDiskMesh(refinement_level=3)

Then, we :ref:`start annotating <adjoint-taping>` and turn the mesh coordinates into a control variable::

    continue_annotation()
    Q = mesh.coordinates.function_space()
    dT = Function(Q)
    mesh.coordinates.assign(mesh.coordinates + dT)

We can now implement the target function::

    x, y = SpatialCoordinate(mesh)
    u_t = Constant(1.21) - (x - Constant(0.5))**2 - (y - Constant(0.5))**2

solve the weak form of the boundary value problem::

    V = FunctionSpace(mesh, "CG", 1)
    u = Function(V, name='state')
    v = TestFunction(V)
    F = (dot(grad(u), grad(v)) - 4 * v) * dx
    bcs = DirichletBC(V, Constant(0.), "on_boundary")
    solve(F == 0, u, bcs=bcs)

and evaluate the objective function::

    J = assemble((u - u_t)**2*dx)

We now turn the objective function into a reduced functional so that pyadjoint
(and UFL's shape differentiation capability) can automatically compute shape
gradients, that is, directions of steepest ascent. We also set the relevant
Riesz map for this problem::

    Jred = ReducedFunctional(J, Control(dT, riesz_map="H1"))
    stop_annotating()

We now have all the ingredients to implement a basic steepest descent shape
optimization algorithm with fixed step size::

    File = VTKFile("shape_iterates.pvd")
    for ii in range(30):
        print("J(ii =", ii, ") =", Jred(dT))
        File.write(mesh.coordinates)

        # compute the gradient (steepest ascent)
        gradJ = Jred.derivative(apply_riesz=True)

        # update domain
        dT -= 0.2*gradJ

    File.write(mesh.coordinates)
    print("J(final) =", Jred(dT))

.. only:: html

   .. container:: youtube

      .. vimeo:: 1083822714?loop=1
         :width: 600px


**Remark:** mesh-moving shape optimization can lead to mesh tangling, which
invalidates finite element computations. For faster and more robust shape
optimization, we recommend using Firedrake's shape optimization toolbox
`Fireshape <https://github.com/fireshape/fireshape>`__.

A python script version of this demo can be found :demo:`here <shape_optimization.py>`.

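The fixed-step steepest descent loop of this demo can be sketched without Firedrake. In the toy below (all names are stand-ins for the demo's objects), the "domain" is reduced to a single displacement vector ``dT``, and ``J`` measures the squared distance to the optimal displacement, so the update ``dT -= 0.2*gradJ`` has the same structure as in the demo::

```python
# Minimal, Firedrake-free sketch of fixed-step steepest descent: the control
# dT plays the role of the mesh displacement, and dJ stands in for
# Jred.derivative(apply_riesz=True).

def J(dT, opt=(0.5, 0.5)):
    # objective: squared distance to the optimal displacement
    return sum((d - o) ** 2 for d, o in zip(dT, opt))

def dJ(dT, opt=(0.5, 0.5)):
    # gradient, i.e. the direction of steepest ascent
    return [2 * (d - o) for d, o in zip(dT, opt)]

dT = [0.0, 0.0]
for ii in range(30):
    gradJ = dJ(dT)
    # update the "domain": move against the gradient with fixed step 0.2
    dT = [d - 0.2 * g for d, g in zip(dT, gradJ)]

final_J = J(dT)
```

With a fixed step the error contracts by a constant factor each iteration; choosing the step too large would make the loop diverge, which is one reason the remark above recommends more robust tools such as Fireshape for real problems.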
firedrake/adjoint/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
 from firedrake.adjoint_utils import get_solve_blocks  # noqa F401

 from pyadjoint.verification import taylor_test, taylor_to_dict  # noqa F401
-from pyadjoint.drivers import compute_gradient, compute_hessian  # noqa F401
+from pyadjoint.drivers import compute_gradient, compute_derivative, compute_hessian  # noqa F401
 from pyadjoint.adjfloat import AdjFloat  # noqa F401
 from pyadjoint.control import Control  # noqa F401
 from pyadjoint import IPOPTSolver, ROLSolver, MinimizationProblem, \

firedrake/adjoint/ensemble_reduced_functional.py

Lines changed: 72 additions & 25 deletions

@@ -1,11 +1,12 @@
-from pyadjoint import ReducedFunctional
+from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional
 from pyadjoint.enlisting import Enlist
 from pyop2.mpi import MPI

-import firedrake
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction


-class EnsembleReducedFunctional(ReducedFunctional):
+class EnsembleReducedFunctional(AbstractReducedFunctional):
     """Enable solving simultaneously reduced functionals in parallel.

     Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`,
@@ -34,7 +35,7 @@ class EnsembleReducedFunctional(ReducedFunctional):

     Parameters
     ----------
-    J : pyadjoint.OverloadedType
+    functional : pyadjoint.OverloadedType
         An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`.
         This should be the functional that we want to reduce.
     control : pyadjoint.Control or list of pyadjoint.Control
@@ -86,28 +87,40 @@ class EnsembleReducedFunctional(ReducedFunctional):
     works, please refer to the `Firedrake manual
     <https://www.firedrakeproject.org/parallelism.html#ensemble-parallelism>`_.
     """
-    def __init__(self, J, control, ensemble, scatter_control=True,
-                 gather_functional=None, derivative_components=None,
-                 scale=1.0, tape=None, eval_cb_pre=lambda *args: None,
+    def __init__(self, functional, control, ensemble, scatter_control=True,
+                 gather_functional=None,
+                 derivative_components=None,
+                 scale=1.0, tape=None,
+                 eval_cb_pre=lambda *args: None,
                  eval_cb_post=lambda *args: None,
                  derivative_cb_pre=lambda controls: controls,
                  derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
-                 hessian_cb_pre=lambda *args: None, hessian_cb_post=lambda *args: None):
-        super(EnsembleReducedFunctional, self).__init__(
-            J, control, derivative_components=derivative_components,
-            scale=scale, tape=tape, eval_cb_pre=eval_cb_pre,
-            eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre,
+                 hessian_cb_pre=lambda *args: None,
+                 hessian_cb_post=lambda *args: None):
+        self.local_reduced_functional = ReducedFunctional(
+            functional, control,
+            derivative_components=derivative_components,
+            scale=scale, tape=tape,
+            eval_cb_pre=eval_cb_pre,
+            eval_cb_post=eval_cb_post,
+            derivative_cb_pre=derivative_cb_pre,
             derivative_cb_post=derivative_cb_post,
-            hessian_cb_pre=hessian_cb_pre, hessian_cb_post=hessian_cb_post)
+            hessian_cb_pre=hessian_cb_pre,
+            hessian_cb_post=hessian_cb_post
+        )

         self.ensemble = ensemble
         self.scatter_control = scatter_control
         self.gather_functional = gather_functional

+    @property
+    def controls(self):
+        return self.local_reduced_functional.controls
+
     def _allgather_J(self, J):
         if isinstance(J, float):
             vals = self.ensemble.ensemble_comm.allgather(J)
-        elif isinstance(J, firedrake.Function):
+        elif isinstance(J, Function):
             # allgather not implemented in ensemble.py
             vals = []
             for i in range(self.ensemble.ensemble_comm.size):
@@ -134,30 +147,31 @@ def __call__(self, values):
             The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`.

         """
-        local_functional = super(EnsembleReducedFunctional, self).__call__(values)
+        local_functional = self.local_reduced_functional(values)
         ensemble_comm = self.ensemble.ensemble_comm
         if self.gather_functional:
             controls_g = self._allgather_J(local_functional)
             total_functional = self.gather_functional(controls_g)
         # if gather_functional is None then we do a sum
         elif isinstance(local_functional, float):
             total_functional = ensemble_comm.allreduce(sendobj=local_functional, op=MPI.SUM)
-        elif isinstance(local_functional, firedrake.Function):
+        elif isinstance(local_functional, Function):
             total_functional = type(local_functional)(local_functional.function_space())
             total_functional = self.ensemble.allreduce(local_functional, total_functional)
         else:
             raise NotImplementedError("This type of functional is not supported.")
         return total_functional

-    def derivative(self, adj_input=1.0, options=None):
+    def derivative(self, adj_input=1.0, apply_riesz=False):
         """Compute derivatives of a functional with respect to the control parameters.

         Parameters
         ----------
         adj_input : float
             The adjoint input.
-        options : dict
-            Additional options for the derivative computation.
+        apply_riesz : bool
+            If True, apply the Riesz map of each control in order to return
+            a primal gradient rather than a derivative in the dual space.

         Returns
         -------
@@ -171,29 +185,62 @@ def derivative(self, adj_input=1.0, options=None):

         if self.gather_functional:
             dJg_dmg = self.gather_functional.derivative(adj_input=adj_input,
-                                                        options=options)
+                                                        apply_riesz=False)
             i = self.ensemble.ensemble_comm.rank
             adj_input = dJg_dmg[i]

-        dJdm_local = super(EnsembleReducedFunctional, self).derivative(adj_input=adj_input, options=options)
+        dJdm_local = self.local_reduced_functional.derivative(adj_input=adj_input,
+                                                              apply_riesz=apply_riesz)

         if self.scatter_control:
             dJdm_local = Enlist(dJdm_local)
             dJdm_total = []

             for dJdm in dJdm_local:
-                if not isinstance(dJdm, (firedrake.Function, float)):
-                    raise NotImplementedError("This type of gradient is not supported.")
+                if not isinstance(dJdm, (Cofunction, Function, float)):
+                    raise NotImplementedError(
+                        f"Gradients of type {type(dJdm).__name__} are not supported.")

                 dJdm_total.append(
                     self.ensemble.allreduce(dJdm, type(dJdm)(dJdm.function_space()))
-                    if isinstance(dJdm, firedrake.Function)
+                    if isinstance(dJdm, (Cofunction, Function))
                     else self.ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM)
                 )
             return dJdm_local.delist(dJdm_total)
         return dJdm_local

-    def hessian(self, m_dot, options=None):
+    def tlm(self, m_dot):
+        """Return the action of the tangent linear model of the functional.
+
+        The tangent linear model is evaluated w.r.t. the control on a vector
+        m_dot, around the last supplied value of the control.
+
+        Parameters
+        ----------
+        m_dot : pyadjoint.OverloadedType
+            The direction in which to compute the action of the tangent linear model.
+
+        Returns
+        -------
+        pyadjoint.OverloadedType: The action of the tangent linear model in the
+            direction m_dot. Should be an instance of the same type as the functional.
+        """
+        local_tlm = self.local_reduced_functional.tlm(m_dot)
+        ensemble_comm = self.ensemble.ensemble_comm
+        if self.gather_functional:
+            mdot_g = self._allgather_J(local_tlm)
+            total_tlm = self.gather_functional.tlm(mdot_g)
+        # if gather_functional is None then we do a sum
+        elif isinstance(local_tlm, float):
+            total_tlm = ensemble_comm.allreduce(sendobj=local_tlm, op=MPI.SUM)
+        elif isinstance(local_tlm, Function):
+            total_tlm = type(local_tlm)(local_tlm.function_space())
+            total_tlm = self.ensemble.allreduce(local_tlm, total_tlm)
+        else:
+            raise NotImplementedError("This type of functional is not supported.")
+        return total_tlm
+
+    def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
         """The Hessian is not yet implemented for ensemble reduced functional.

         Raises:
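The reduction pattern this class implements (for the default `gather_functional=None` case) can be sketched without MPI at all. In the toy below (all names are illustrative, not the real API), the `allreduce` over the ensemble communicator is replaced by a plain Python sum over a list of "ranks", each owning one local reduced functional:

```python
# Sketch of the ensemble reduction in EnsembleReducedFunctional with
# gather_functional=None: the combination rule is a sum, which the real
# code performs as an MPI.SUM allreduce over the ensemble communicator.

class LocalRF:
    """One ensemble member's reduced functional: J_s(m) = w_s * m^2."""
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, m):
        return self.weight * m * m

    def derivative(self, m):
        return 2 * self.weight * m


class ToyEnsembleRF:
    """Combines the local functionals: J(m) = sum_s J_s(m), and likewise
    for the derivative (the allreduce is a plain sum here)."""
    def __init__(self, local_rfs):
        self.local_rfs = local_rfs

    def __call__(self, m):
        # stands in for ensemble_comm.allreduce(J_s, op=MPI.SUM)
        return sum(rf(m) for rf in self.local_rfs)

    def derivative(self, m):
        # stands in for the allreduce over local gradients dJ_s/dm
        return sum(rf.derivative(m) for rf in self.local_rfs)


ensemble = ToyEnsembleRF([LocalRF(w) for w in (1.0, 2.0, 3.0)])
J_total = ensemble(2.0)             # (1 + 2 + 3) * 2^2 = 24.0
dJ_total = ensemble.derivative(2.0)  # 2 * (1 + 2 + 3) * 2 = 24.0
```

In the real class each rank evaluates only its own `local_reduced_functional`; the sum happens across processes, which is why every method above (`__call__`, `derivative`, `tlm`) ends in an allreduce.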