Commit 79900ff

Dham/abstract reduced functional (#3941)

Authored by: dham, Ig-dolci, JHopeCollins, pbrubeck
Co-authored-by: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com>
Co-authored-by: Josh Hope-Collins <joshua.hope-collins13@imperial.ac.uk>
Co-authored-by: Pablo Brubeck <brubeck@protonmail.com>

1 parent 232c520, commit 79900ff

28 files changed: +612 -318 lines changed

demos/full_waveform_inversion/full_waveform_inversion.py.rst
Lines changed: 16 additions & 9 deletions

@@ -109,10 +109,12 @@ built over the ``my_ensemble.comm`` (spatial) communicator.
     dt = 0.03  # time step in seconds
     final_time = 0.6  # final time in seconds
     nx, ny = 15, 15
+    ftol = 0.9  # optimisation tolerance
 else:
     dt = 0.002  # time step in seconds
     final_time = 1.0  # final time in seconds
     nx, ny = 80, 80
+    ftol = 1e-2  # optimisation tolerance

 mesh = UnitSquareMesh(nx, ny, comm=my_ensemble.comm)

@@ -278,21 +280,28 @@ To have the step 4, we need first to tape the forward problem. That is done by c

 We now instantiate :class:`~.EnsembleReducedFunctional`::

-    J_hat = EnsembleReducedFunctional(J_val, Control(c_guess), my_ensemble)
+    J_hat = EnsembleReducedFunctional(J_val,
+                                      Control(c_guess, riesz_map="l2"),
+                                      my_ensemble)

 which enables us to recompute :math:`J` and its gradient :math:`\nabla_{\mathtt{c\_guess}} J`,
 where the :math:`J_s` and its gradients :math:`\nabla_{\mathtt{c\_guess}} J_s` are computed in parallel
 based on the ``my_ensemble`` configuration.


 **Steps 4-6**: The instance of the :class:`~.EnsembleReducedFunctional`, named ``J_hat``,
-is then passed as an argument to the ``minimize`` function::
+is then passed as an argument to the ``minimize`` function. The default ``minimize`` function
+uses ``scipy.minimize``, and wraps the ``ReducedFunctional`` in a ``ReducedFunctionalNumPy``
+that handles transferring data between Firedrake and NumPy data structures. However, because
+we have a custom ``ReducedFunctional``, we need to do this ourselves::

-    c_optimised = minimize(J_hat, method="L-BFGS-B", options={"disp": True, "maxiter": 1},
-                           bounds=(1.5, 2.0), derivative_options={"riesz_representation": 'l2'}
-                           )
+    from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy
+    Jnumpy = ReducedFunctionalNumPy(J_hat)

-The ``minimize`` function executes the optimisation algorithm until the stopping criterion (``maxiter``) is met.
+    c_optimised = minimize(Jnumpy, method="L-BFGS-B", options={"disp": True, "ftol": ftol},
+                           bounds=(1.5, 2.0))
+
+The ``minimize`` function executes the optimisation algorithm until the stopping criterion (``ftol``) is met.
 For 20 iterations, the predicted velocity model is shown in the following figure.

 .. image:: c_predicted.png
@@ -303,9 +312,7 @@ For 20 iterations, the predicted velocity model is shown in the following figure
 .. warning::

     The ``minimize`` function uses the SciPy library for optimisation. However, for scenarios that require higher
-    levels of spatial parallelism, you should assess whether SciPy is the most suitable option for your problem.
-    SciPy's optimisation algorithm is not inner-product-aware. Therefore, we configure the options with
-    ``derivative_options={"riesz_representation": 'l2'}`` to account for this requirement.
+    levels of spatial parallelism, you should assess whether SciPy is the most suitable option for your problem; pyadjoint's ``TAOSolver`` is one alternative.

 .. note::

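The demo change above notes that the default ``minimize`` wraps a ``ReducedFunctional`` in a ``ReducedFunctionalNumPy``, which shuttles data between Firedrake objects and flat NumPy arrays for the optimiser. A minimal pure-Python sketch of that flatten/unflatten wrapping pattern (``FlatArrayWrapper`` and the dict-valued control are hypothetical stand-ins, not Firedrake or pyadjoint API):

```python
# Hypothetical sketch: an optimiser that only understands flat arrays is
# bridged to a functional whose native control is a richer object
# (here, a dict of named fields), mirroring what ReducedFunctionalNumPy does.

class FlatArrayWrapper:
    """Bridge a functional on dict-valued controls to a flat-list interface."""

    def __init__(self, functional, template):
        self.functional = functional
        self.keys = sorted(template)  # fixed ordering for flatten/unflatten

    def unflatten(self, flat):
        # Rebuild the structured control object from the flat array.
        return {k: v for k, v in zip(self.keys, flat)}

    def __call__(self, flat):
        # Evaluate the underlying functional on the reconstructed control.
        return self.functional(self.unflatten(flat))


def J(control):
    # A toy "reduced functional": sum of squared control values.
    return sum(v ** 2 for v in control.values())


template = {"c_guess": 0.0, "rho": 0.0}
Jnumpy = FlatArrayWrapper(J, template)
print(Jnumpy([3.0, 4.0]))  # evaluates J({"c_guess": 3.0, "rho": 4.0}) = 25.0
```

The fixed key ordering is the essential design point: both evaluation and any gradient returned to the optimiser must use the same flattening convention.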

demos/shape_optimization/shape_optimization.py.rst
Lines changed: 4 additions & 4 deletions

@@ -70,9 +70,10 @@ and evaluate the objective function::

 We now turn the objective function into a reduced function so that pyadjoint
 (and UFL shape differentiation capability) can automatically compute shape
-gradients, that is, directions of steepest ascent::
+gradients, that is, directions of steepest ascent. We also set the relevant
+Riesz map for this problem::

-    Jred = ReducedFunctional(J, Control(dT))
+    Jred = ReducedFunctional(J, Control(dT, riesz_map="H1"))
     stop_annotating()

@@ -84,8 +85,7 @@ optimization algorithm with fixed step size.::
         File.write(mesh.coordinates)

         # compute the gradient (steepest ascent)
-        opts = {"riesz_representation": "H1"}
-        gradJ = Jred.derivative(options=opts)
+        gradJ = Jred.derivative(apply_riesz=True)

         # update domain
         dT -= 0.2*gradJ

firedrake/adjoint/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
 from firedrake.adjoint_utils import get_solve_blocks  # noqa F401

 from pyadjoint.verification import taylor_test, taylor_to_dict  # noqa F401
-from pyadjoint.drivers import compute_gradient, compute_hessian  # noqa F401
+from pyadjoint.drivers import compute_gradient, compute_derivative, compute_hessian  # noqa F401
 from pyadjoint.adjfloat import AdjFloat  # noqa F401
 from pyadjoint.control import Control  # noqa F401
 from pyadjoint import IPOPTSolver, ROLSolver, MinimizationProblem, \

firedrake/adjoint/ensemble_reduced_functional.py
Lines changed: 72 additions & 25 deletions

@@ -1,11 +1,12 @@
-from pyadjoint import ReducedFunctional
+from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional
 from pyadjoint.enlisting import Enlist
 from pyop2.mpi import MPI

-import firedrake
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction


-class EnsembleReducedFunctional(ReducedFunctional):
+class EnsembleReducedFunctional(AbstractReducedFunctional):
     """Enable solving simultaneously reduced functionals in parallel.

     Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`,
@@ -34,7 +35,7 @@ class EnsembleReducedFunctional(ReducedFunctional):

     Parameters
     ----------
-    J : pyadjoint.OverloadedType
+    functional : pyadjoint.OverloadedType
         An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`.
         This should be the functional that we want to reduce.
     control : pyadjoint.Control or list of pyadjoint.Control
@@ -86,28 +87,40 @@ class EnsembleReducedFunctional(ReducedFunctional):
     works, please refer to the `Firedrake manual
     <https://www.firedrakeproject.org/parallelism.html#ensemble-parallelism>`_.
     """
-    def __init__(self, J, control, ensemble, scatter_control=True,
-                 gather_functional=None, derivative_components=None,
-                 scale=1.0, tape=None, eval_cb_pre=lambda *args: None,
+    def __init__(self, functional, control, ensemble, scatter_control=True,
+                 gather_functional=None,
+                 derivative_components=None,
+                 scale=1.0, tape=None,
+                 eval_cb_pre=lambda *args: None,
                  eval_cb_post=lambda *args: None,
                  derivative_cb_pre=lambda controls: controls,
                  derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
-                 hessian_cb_pre=lambda *args: None, hessian_cb_post=lambda *args: None):
-        super(EnsembleReducedFunctional, self).__init__(
-            J, control, derivative_components=derivative_components,
-            scale=scale, tape=tape, eval_cb_pre=eval_cb_pre,
-            eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre,
+                 hessian_cb_pre=lambda *args: None,
+                 hessian_cb_post=lambda *args: None):
+        self.local_reduced_functional = ReducedFunctional(
+            functional, control,
+            derivative_components=derivative_components,
+            scale=scale, tape=tape,
+            eval_cb_pre=eval_cb_pre,
+            eval_cb_post=eval_cb_post,
+            derivative_cb_pre=derivative_cb_pre,
             derivative_cb_post=derivative_cb_post,
-            hessian_cb_pre=hessian_cb_pre, hessian_cb_post=hessian_cb_post)
+            hessian_cb_pre=hessian_cb_pre,
+            hessian_cb_post=hessian_cb_post
+        )

         self.ensemble = ensemble
         self.scatter_control = scatter_control
         self.gather_functional = gather_functional

+    @property
+    def controls(self):
+        return self.local_reduced_functional.controls
+
     def _allgather_J(self, J):
         if isinstance(J, float):
             vals = self.ensemble.ensemble_comm.allgather(J)
-        elif isinstance(J, firedrake.Function):
+        elif isinstance(J, Function):
             # allgather not implemented in ensemble.py
             vals = []
             for i in range(self.ensemble.ensemble_comm.size):
@@ -134,30 +147,31 @@ def __call__(self, values):
             The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`.

         """
-        local_functional = super(EnsembleReducedFunctional, self).__call__(values)
+        local_functional = self.local_reduced_functional(values)
         ensemble_comm = self.ensemble.ensemble_comm
         if self.gather_functional:
             controls_g = self._allgather_J(local_functional)
             total_functional = self.gather_functional(controls_g)
         # if gather_functional is None then we do a sum
         elif isinstance(local_functional, float):
             total_functional = ensemble_comm.allreduce(sendobj=local_functional, op=MPI.SUM)
-        elif isinstance(local_functional, firedrake.Function):
+        elif isinstance(local_functional, Function):
             total_functional = type(local_functional)(local_functional.function_space())
             total_functional = self.ensemble.allreduce(local_functional, total_functional)
         else:
             raise NotImplementedError("This type of functional is not supported.")
         return total_functional

-    def derivative(self, adj_input=1.0, options=None):
+    def derivative(self, adj_input=1.0, apply_riesz=False):
         """Compute derivatives of a functional with respect to the control parameters.

         Parameters
         ----------
         adj_input : float
             The adjoint input.
-        options : dict
-            Additional options for the derivative computation.
+        apply_riesz: bool
+            If True, apply the Riesz map of each control in order to return
+            a primal gradient rather than a derivative in the dual space.

         Returns
         -------
@@ -171,29 +185,62 @@ def derivative(self, adj_input=1.0, apply_riesz=False):

         if self.gather_functional:
             dJg_dmg = self.gather_functional.derivative(adj_input=adj_input,
-                                                        options=options)
+                                                        apply_riesz=False)
             i = self.ensemble.ensemble_comm.rank
             adj_input = dJg_dmg[i]

-        dJdm_local = super(EnsembleReducedFunctional, self).derivative(adj_input=adj_input, options=options)
+        dJdm_local = self.local_reduced_functional.derivative(adj_input=adj_input,
+                                                              apply_riesz=apply_riesz)

         if self.scatter_control:
             dJdm_local = Enlist(dJdm_local)
             dJdm_total = []

             for dJdm in dJdm_local:
-                if not isinstance(dJdm, (firedrake.Function, float)):
-                    raise NotImplementedError("This type of gradient is not supported.")
+                if not isinstance(dJdm, (Cofunction, Function, float)):
+                    raise NotImplementedError(
+                        f"Gradients of type {type(dJdm).__name__} are not supported.")

                 dJdm_total.append(
                     self.ensemble.allreduce(dJdm, type(dJdm)(dJdm.function_space()))
-                    if isinstance(dJdm, firedrake.Function)
+                    if isinstance(dJdm, (Cofunction, Function))
                     else self.ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM)
                 )
             return dJdm_local.delist(dJdm_total)
         return dJdm_local

-    def hessian(self, m_dot, options=None):
+    def tlm(self, m_dot):
+        """Return the action of the tangent linear model of the functional.
+
+        The tangent linear model is evaluated w.r.t. the control on a vector
+        m_dot, around the last supplied value of the control.
+
+        Parameters
+        ----------
+        m_dot : pyadjoint.OverloadedType
+            The direction in which to compute the action of the tangent linear model.
+
+        Returns
+        -------
+        pyadjoint.OverloadedType: The action of the tangent linear model in the
+            direction m_dot. Should be an instance of the same type as the functional.
+        """
+        local_tlm = self.local_reduced_functional.tlm(m_dot)
+        ensemble_comm = self.ensemble.ensemble_comm
+        if self.gather_functional:
+            mdot_g = self._allgather_J(local_tlm)
+            total_tlm = self.gather_functional.tlm(mdot_g)
+        # if gather_functional is None then we do a sum
+        elif isinstance(local_tlm, float):
+            total_tlm = ensemble_comm.allreduce(sendobj=local_tlm, op=MPI.SUM)
+        elif isinstance(local_tlm, Function):
+            total_tlm = type(local_tlm)(local_tlm.function_space())
+            total_tlm = self.ensemble.allreduce(local_tlm, total_tlm)
+        else:
+            raise NotImplementedError("This type of functional is not supported.")
+        return total_tlm
+
+    def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
         """The Hessian is not yet implemented for ensemble reduced functional.

         Raises:
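The class above combines per-member functionals by an allreduce sum: each ensemble rank evaluates its own local reduced functional, and both the value and the derivative are summed across ranks (``MPI.SUM``). A serial, MPI-free sketch of that structure (all class names here are toy stand-ins, not the Firedrake API):

```python
# Serial stand-in for the ensemble pattern: each "rank" owns a local reduced
# functional J_s(m); the ensemble value is the allreduce-sum of local values,
# and by linearity the ensemble derivative is the sum of local derivatives.

class LocalReducedFunctional:
    def __init__(self, source):
        self.source = source  # e.g. one seismic source per ensemble member
        self.m = None

    def __call__(self, m):
        self.m = m
        return (m - self.source) ** 2   # toy misfit for this member

    def derivative(self):
        return 2.0 * (self.m - self.source)


class ToyEnsembleReducedFunctional:
    """Delegates to local functionals, like EnsembleReducedFunctional
    delegating to self.local_reduced_functional on each rank."""

    def __init__(self, locals_):
        self.locals_ = locals_

    def __call__(self, m):
        # Mimics ensemble_comm.allreduce(..., op=MPI.SUM) over the values.
        return sum(rf(m) for rf in self.locals_)

    def derivative(self):
        # Sum of local gradients, as in the scatter_control branch.
        return sum(rf.derivative() for rf in self.locals_)


J_hat = ToyEnsembleReducedFunctional(
    [LocalReducedFunctional(s) for s in (1.0, 3.0)])
print(J_hat(2.0))          # (2-1)^2 + (2-3)^2 = 2.0
print(J_hat.derivative())  # 2*(2-1) + 2*(2-3) = 0.0
```

Note how the commit's move from inheritance (`super().__call__`) to composition (`self.local_reduced_functional`) is mirrored here: the ensemble object is a combinator over local functionals rather than a subclass of one.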

firedrake/adjoint_utils/blocks/solving.py
Lines changed: 16 additions & 4 deletions

@@ -86,8 +86,15 @@ def _init_solver_parameters(self, args, kwargs):
         self.assemble_kwargs = {}

     def __str__(self):
-        return "solve({} = {})".format(ufl2unicode(self.lhs),
-                                       ufl2unicode(self.rhs))
+        try:
+            lhs_string = ufl2unicode(self.lhs)
+        except AttributeError:
+            lhs_string = str(self.lhs)
+        try:
+            rhs_string = ufl2unicode(self.rhs)
+        except AttributeError:
+            rhs_string = str(self.rhs)
+        return "solve({} = {})".format(lhs_string, rhs_string)

     def _create_F_form(self):
         # Process the equation forms, replacing values with checkpoints,
@@ -742,7 +749,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
         c = block_variable.output
         c_rep = block_variable.saved_output

-        if isinstance(c, firedrake.Function):
+        if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
             trial_function = firedrake.TrialFunction(c.function_space())
         elif isinstance(c, firedrake.Constant):
             mesh = F_form.ufl_domain()
@@ -779,7 +786,12 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
         replace_map[self.func] = self.get_outputs()[0].saved_output
         dFdm = replace(dFdm, replace_map)

-        dFdm = dFdm * adj_sol
+        if isinstance(dFdm, firedrake.Argument):
+            # Corner case. Should be fixed more permanently upstream in UFL.
+            # See: https://github.com/FEniCS/ufl/issues/395
+            dFdm = ufl.Action(dFdm, adj_sol)
+        else:
+            dFdm = dFdm * adj_sol
         dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)

         return dFdm
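The ``__str__`` fix above uses a try/except fallback so that objects ``ufl2unicode`` cannot render (raising ``AttributeError``) still produce a usable string instead of crashing the repr. The same defensive pattern in isolation (``pretty_print`` is a hypothetical stand-in for ``ufl2unicode``):

```python
# Defensive string formatting: attempt a rich renderer, fall back to str().

def pretty_print(obj):
    # Hypothetical renderer that only handles objects with a .symbol attribute;
    # anything else raises AttributeError, like ufl2unicode on non-UFL objects.
    return f"<{obj.symbol}>"

def describe(obj):
    try:
        return pretty_print(obj)
    except AttributeError:
        # Fall back to the plain str() representation.
        return str(obj)

class Sym:
    symbol = "u"

print(describe(Sym()))  # <u>
print(describe(42))     # 42
```

Catching only ``AttributeError`` (rather than a bare ``except``) keeps genuine bugs in the renderer visible while still tolerating unsupported operand types.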

firedrake/adjoint_utils/constant.py
Lines changed: 2 additions & 10 deletions

@@ -2,7 +2,6 @@
 from pyadjoint.adjfloat import AdjFloat
 from pyadjoint.tape import get_working_tape, annotate_tape
 from pyadjoint.overloaded_type import OverloadedType, create_overloaded_object
-from pyadjoint.reduced_functional_numpy import gather

 from firedrake.functionspace import FunctionSpace
 from firedrake.adjoint_utils.blocks import ConstantAssignBlock
@@ -58,15 +57,8 @@ def wrapper(self, *args, **kwargs):

         return wrapper

-    def get_derivative(self, options={}):
-        return self._ad_convert_type(self.adj_value, options=options)
-
-    def _ad_convert_type(self, value, options={}):
-        if value is None:
-            # TODO: Should the default be 0 constant here or return just None?
-            return type(self)(numpy.zeros(self.ufl_shape))
-        value = gather(value)
-        return self._constant_from_values(value)
+    def _ad_init_zero(self, dual=False):
+        return type(self)(numpy.zeros(self.ufl_shape))

     def _ad_function_space(self, mesh):
         element = self.ufl_element()

firedrake/adjoint_utils/ensemble_function.py
Lines changed: 7 additions & 0 deletions

@@ -55,6 +55,13 @@ def _ad_dot(self, other, options=None):
     def _ad_convert_riesz(self, value, options=None):
         raise NotImplementedError

+    def _ad_init_zero(self, dual=False):
+        from firedrake import EnsembleFunction, EnsembleCofunction
+        if dual:
+            return EnsembleCofunction(self.function_space().dual())
+        else:
+            return EnsembleFunction(self.function_space())
+
     def _ad_create_checkpoint(self):
         if disk_checkpointing():
             raise NotImplementedError(
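Both ``constant.py`` and ``ensemble_function.py`` gain an ``_ad_init_zero(dual=...)`` hook in this commit: each overloaded type knows how to construct a zero of itself, either in its own (primal) space or in the dual space. A toy sketch of the idiom (``ToySpace``, ``ToyFunction``, and ``ToyCofunction`` are invented stand-ins for the real function-space and EnsembleFunction/EnsembleCofunction classes):

```python
# Sketch of the _ad_init_zero(dual=...) idiom: a type builds a zero of itself
# in its primal space, or a zero covector in the corresponding dual space.

class ToySpace:
    def __init__(self, dim, is_dual=False):
        self.dim = dim
        self.is_dual = is_dual

    def dual(self):
        # Dual of the dual returns to the primal space in this toy model.
        return ToySpace(self.dim, is_dual=not self.is_dual)


class ToyFunction:
    def __init__(self, space):
        self.space = space
        self.values = [0.0] * space.dim  # newly created objects are zero

    def _ad_init_zero(self, dual=False):
        if dual:
            return ToyCofunction(self.space.dual())
        return ToyFunction(self.space)


class ToyCofunction(ToyFunction):
    """Dual-space counterpart, like EnsembleCofunction vs EnsembleFunction."""


f = ToyFunction(ToySpace(3))
z = f._ad_init_zero()            # zero in the primal space
zstar = f._ad_init_zero(dual=True)  # zero covector in the dual space
print(type(z).__name__, type(zstar).__name__, zstar.space.is_dual)
```

Giving each type its own zero-constructor lets generic adjoint drivers initialise accumulators without knowing whether they hold primal values or dual-space derivatives.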
