Skip to content
Merged
Show file tree
Hide file tree
Changes from 91 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
109b39d
test for correct annotation of cofunction assign
dham Dec 17, 2024
aa010bf
DROP BEFORE MERGE
dham Jul 8, 2024
e56f5ff
beginnings of doing inner product right
dham Jul 1, 2024
6fc8cdb
rieszmap
dham Jul 1, 2024
c373c31
missing import
dham Jul 1, 2024
90cd52b
fix l2 for mixed spaces
dham Jul 4, 2024
fc45c48
fix reisz map for primal case
dham Jul 4, 2024
1c16bb5
riesz_representation
dham Jul 4, 2024
42f974c
Update firedrake/adjoint_utils/blocks/solving.py
dham Jul 4, 2024
8b93fe4
rearrange to pass documentation
dham Jul 4, 2024
85a8267
fix solver options
dham Jul 4, 2024
6768c9f
make solve string more robust
dham Jul 6, 2024
f98e837
docstrings
dham Jul 8, 2024
1d8e58a
remove merged branch reference
dham Dec 17, 2024
cc402aa
lint
dham Dec 17, 2024
cde10ed
Remove unused solver options
dham Dec 17, 2024
8f86232
wind back API changes
dham Dec 18, 2024
720ad1d
test for correct annotation of cofunction assign
dham Dec 18, 2024
674e4bb
overloadedtype interface changes
dham Dec 18, 2024
de14bf3
interface update in test
dham Dec 18, 2024
cf57ae0
fix GlobalDataSet._apply_local_global_filter
dham Dec 18, 2024
55e1194
DROP BEFORE MERGE
dham Dec 18, 2024
112750e
lint
dham Dec 18, 2024
c408ee3
typo
dham Dec 29, 2024
da47938
fix interface in more tests
dham Jan 2, 2025
7d03656
fix pytorch interface
dham Jan 2, 2025
2478392
Merge branch 'master' into dham/abstract_reduced_functional
dham Jan 29, 2025
7af9880
fix dual spec
dham Feb 7, 2025
8e4ae06
Update pyadjoint interface
dham Feb 8, 2025
b92e400
apply riesz where needed
dham Feb 8, 2025
5fb8675
Merge branch 'master' into dham/abstract_reduced_functional
dham Mar 25, 2025
a96ee52
fix pyadjoint branch install command
JHopeCollins Mar 25, 2025
a40f420
AbstractReducedFunctional updates for EnsembleReducedFunctional
JHopeCollins Mar 25, 2025
b5afaca
fix pyadjoint pip install command??
JHopeCollins Mar 25, 2025
da694b0
change test derivative to old primal return value
JHopeCollins Mar 25, 2025
5adabb3
apply Riesz when needed
dham Mar 25, 2025
78c2ab3
demo doc typo
JHopeCollins Mar 25, 2025
7aac8b3
Merge branch 'dham/abstract_reduced_functional' of github.com:firedra…
JHopeCollins Mar 25, 2025
75247c6
don't restrict without BCs
dham Mar 25, 2025
73f50fc
remove old ad methods
JHopeCollins Mar 25, 2025
b2fee42
apply_riesz in adjoint tests
JHopeCollins Mar 25, 2025
859bc01
Merge branch 'dham/abstract_reduced_functional' of github.com:firedra…
JHopeCollins Mar 25, 2025
a19b280
set test back to parallel
dham Mar 26, 2025
565b9ba
fix fwi demo class reference
JHopeCollins Mar 26, 2025
6ea43a4
more apply_riesz in adjoint tests
JHopeCollins Mar 26, 2025
2dab55a
move riesz map to control
dham Mar 26, 2025
4b027e8
ensemble rf docs fix?
JHopeCollins Mar 26, 2025
f10c113
rf riesz options in Jax compat
JHopeCollins Mar 26, 2025
506bc4f
restore riesz map comparison to optimisation test
JHopeCollins Mar 26, 2025
83dacc4
Fix bcs
pbrubeck Mar 26, 2025
0ad26ba
remove spurious split method
dham Mar 26, 2025
12df6ac
make restrict a parameter to RieszMap
dham Mar 26, 2025
7bbf2cc
DROP BEFORE MERGE
dham Mar 26, 2025
2483ef4
Update tests/firedrake/adjoint/test_ensemble_reduced_functional.py
dham Mar 26, 2025
c626576
Update .github/workflows/docs.yml
dham Mar 26, 2025
0f8d13c
Test pre_apply_bcs in adjoint
pbrubeck Mar 26, 2025
a4e7e90
Merge branch 'dham/abstract_reduced_functional' of github.com:firedra…
pbrubeck Mar 26, 2025
6ffee32
compute_gradient -> compute_derivative
JHopeCollins Mar 26, 2025
45d4818
Merge branch 'dham/abstract_reduced_functional' of https://github.com…
JHopeCollins Mar 26, 2025
112892e
random echo issue
dham Mar 26, 2025
59940d8
Update tests/firedrake/output/test_adjoint_disk_checkpointing.py
dham Mar 26, 2025
4a3c02e
Update firedrake/adjoint/__init__.py
dham Mar 26, 2025
2a3dc67
Merge remote-tracking branch 'origin/master' into dham/abstract_reduc…
dham Apr 22, 2025
4e7fa49
Merge branch 'master' into dham/abstract_reduced_functional
JHopeCollins May 7, 2025
caf0f53
REVERT BEFORE MERGE: checkout pyadjoint branch
JHopeCollins May 7, 2025
1eb1864
Merge branch 'master' into dham/abstract_reduced_functional
JHopeCollins May 9, 2025
e81ea51
Merge remote-tracking branch 'origin/master' into dham/abstract_reduc…
dham Jun 27, 2025
22e6730
loosen unnecessarily tight tolerances on tao optimisation tests
JHopeCollins Jun 30, 2025
57271a7
Merge remote-tracking branch 'origin/master' into dham/abstract_reduc…
dham Jul 3, 2025
0de241e
Inner product jacobian is constant.
dham Jul 4, 2025
bd7aa77
Apply suggestions from code review
dham Jul 9, 2025
b071ea1
Github issue for the UFL problem
dham Jul 16, 2025
f3f5d11
lint
dham Jul 16, 2025
d7e30f1
Merge remote-tracking branch 'origin/master' into dham/abstract_reduc…
dham Jul 16, 2025
9cb41c1
Merge branch 'dham/abstract_reduced_functional' of ssh://github.com/f…
JHopeCollins Jul 17, 2025
fa0a850
Merge branch 'master' into dham/abstract_reduced_functional
JHopeCollins Jul 17, 2025
8fb6f56
Test `ReducedFunctional.tlm` (#4448)
JHopeCollins Jul 23, 2025
59c3803
Merge branch 'master' into dham/abstract_reduced_functional
JHopeCollins Jul 25, 2025
3b77dcf
Merge branch 'master' into dham/abstract_reduced_functional
JHopeCollins Jul 30, 2025
6439b05
Merge branch 'main' into dham/abstract_reduced_functional
JHopeCollins Aug 4, 2025
c2d2940
adjoint test fixes
JHopeCollins Aug 5, 2025
7bf2d47
set riesz map correctly in shape optimisation demo
JHopeCollins Aug 5, 2025
b6bff26
update fwi demo tolerance
JHopeCollins Aug 5, 2025
e6d6731
Trigger CI
JHopeCollins Aug 6, 2025
62f4921
test hessian for solve and nlvs
JHopeCollins Aug 6, 2025
3740a20
fwi tolerances
JHopeCollins Aug 6, 2025
ce61e92
fwi demo tolerances
JHopeCollins Aug 11, 2025
52a8957
Merge branch 'dham/abstract_reduced_functional' of github.com:firedra…
JHopeCollins Aug 13, 2025
eb8599b
Merge branch 'main' into dham/abstract_reduced_functional
JHopeCollins Aug 13, 2025
495ce87
_ad_init_zero for EnsembleFunction
JHopeCollins Aug 13, 2025
9723fea
Update .github/workflows/core.yml
JHopeCollins Aug 13, 2025
f7da8f8
Update .github/workflows/core.yml
JHopeCollins Aug 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ jobs:
pip install -U pip
pip install --group ./firedrake-repo/pyproject.toml:ci


firedrake-clean
pip list

Expand Down
25 changes: 16 additions & 9 deletions demos/full_waveform_inversion/full_waveform_inversion.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ built over the ``my_ensemble.comm`` (spatial) communicator.
dt = 0.03 # time step in seconds
final_time = 0.6 # final time in seconds
nx, ny = 15, 15
ftol = 0.9 # optimisation tolerance
else:
dt = 0.002 # time step in seconds
final_time = 1.0 # final time in seconds
nx, ny = 80, 80
ftol = 1e-2 # optimisation tolerance

mesh = UnitSquareMesh(nx, ny, comm=my_ensemble.comm)

Expand Down Expand Up @@ -278,21 +280,28 @@ To have the step 4, we need first to tape the forward problem. That is done by c

We now instantiate :class:`~.EnsembleReducedFunctional`::

J_hat = EnsembleReducedFunctional(J_val, Control(c_guess), my_ensemble)
J_hat = EnsembleReducedFunctional(J_val,
Control(c_guess, riesz_map="l2"),
my_ensemble)

which enables us to recompute :math:`J` and its gradient :math:`\nabla_{\mathtt{c\_guess}} J`,
where the :math:`J_s` and its gradients :math:`\nabla_{\mathtt{c\_guess}} J_s` are computed in parallel
based on the ``my_ensemble`` configuration.


**Steps 4-6**: The instance of the :class:`~.EnsembleReducedFunctional`, named ``J_hat``,
is then passed as an argument to the ``minimize`` function::
is then passed as an argument to the ``minimize`` function. The default ``minimize`` function
uses ``scipy.minimize``, and wraps the ``ReducedFunctional`` in a ``ReducedFunctionalNumPy``
that handles transferring data between Firedrake and numpy data structures. However, because
we have a custom ``ReducedFunctional``, we need to do this ourselves::

c_optimised = minimize(J_hat, method="L-BFGS-B", options={"disp": True, "maxiter": 1},
bounds=(1.5, 2.0), derivative_options={"riesz_representation": 'l2'}
)
from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy
Jnumpy = ReducedFunctionalNumPy(J_hat)

The ``minimize`` function executes the optimisation algorithm until the stopping criterion (``maxiter``) is met.
c_optimised = minimize(Jnumpy, method="L-BFGS-B", options={"disp": True, "ftol": ftol},
bounds=(1.5, 2.0))

The ``minimize`` function executes the optimisation algorithm until the stopping criterion (``ftol``) is met.
For 20 iterations, the predicted velocity model is shown in the following figure.

.. image:: c_predicted.png
Expand All @@ -303,9 +312,7 @@ For 20 iterations, the predicted velocity model is shown in the following figure
.. warning::

The ``minimize`` function uses the SciPy library for optimisation. However, for scenarios that require higher
levels of spatial parallelism, you should assess whether SciPy is the most suitable option for your problem.
SciPy's optimisation algorithm is not inner-product-aware. Therefore, we configure the options with
``derivative_options={"riesz_representation": 'l2'}`` to account for this requirement.
levels of spatial parallelism, you should assess whether SciPy is the most suitable option for your problem such as the pyadjoint's TAOSolver.

.. note::

Expand Down
8 changes: 4 additions & 4 deletions demos/shape_optimization/shape_optimization.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ and evaluate the objective function::

We now turn the objective function into a reduced function so that pyadjoint
(and UFL shape differentiation capability) can automatically compute shape
gradients, that is, directions of steepest ascent::
gradients, that is, directions of steepest ascent. We also set the relevant
Riesz map for this problem::

Jred = ReducedFunctional(J, Control(dT))
Jred = ReducedFunctional(J, Control(dT, riesz_map="H1"))
stop_annotating()

We now have all the ingredients to implement a basic steepest descent shape
Expand All @@ -84,8 +85,7 @@ optimization algorithm with fixed step size.::
File.write(mesh.coordinates)

# compute the gradient (steepest ascent)
opts = {"riesz_representation": "H1"}
gradJ = Jred.derivative(options=opts)
gradJ = Jred.derivative(apply_riesz=True)

# update domain
dT -= 0.2*gradJ
Expand Down
2 changes: 1 addition & 1 deletion firedrake/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from firedrake.adjoint_utils import get_solve_blocks # noqa F401

from pyadjoint.verification import taylor_test, taylor_to_dict # noqa F401
from pyadjoint.drivers import compute_gradient, compute_hessian # noqa F401
from pyadjoint.drivers import compute_gradient, compute_derivative, compute_hessian # noqa F401
from pyadjoint.adjfloat import AdjFloat # noqa F401
from pyadjoint.control import Control # noqa F401
from pyadjoint import IPOPTSolver, ROLSolver, MinimizationProblem, \
Expand Down
97 changes: 72 additions & 25 deletions firedrake/adjoint/ensemble_reduced_functional.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from pyadjoint import ReducedFunctional
from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional
from pyadjoint.enlisting import Enlist
from pyop2.mpi import MPI

import firedrake
from firedrake.function import Function
from firedrake.cofunction import Cofunction


class EnsembleReducedFunctional(ReducedFunctional):
class EnsembleReducedFunctional(AbstractReducedFunctional):
"""Enable solving simultaneously reduced functionals in parallel.

Consider a functional :math:`J` and its gradient :math:`\\dfrac{dJ}{dm}`,
Expand Down Expand Up @@ -34,7 +35,7 @@ class EnsembleReducedFunctional(ReducedFunctional):

Parameters
----------
J : pyadjoint.OverloadedType
functional : pyadjoint.OverloadedType
An instance of an OverloadedType, usually :class:`pyadjoint.AdjFloat`.
This should be the functional that we want to reduce.
control : pyadjoint.Control or list of pyadjoint.Control
Expand Down Expand Up @@ -86,28 +87,40 @@ class EnsembleReducedFunctional(ReducedFunctional):
works, please refer to the `Firedrake manual
<https://www.firedrakeproject.org/parallelism.html#ensemble-parallelism>`_.
"""
def __init__(self, J, control, ensemble, scatter_control=True,
gather_functional=None, derivative_components=None,
scale=1.0, tape=None, eval_cb_pre=lambda *args: None,
def __init__(self, functional, control, ensemble, scatter_control=True,
gather_functional=None,
derivative_components=None,
scale=1.0, tape=None,
eval_cb_pre=lambda *args: None,
eval_cb_post=lambda *args: None,
derivative_cb_pre=lambda controls: controls,
derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
hessian_cb_pre=lambda *args: None, hessian_cb_post=lambda *args: None):
super(EnsembleReducedFunctional, self).__init__(
J, control, derivative_components=derivative_components,
scale=scale, tape=tape, eval_cb_pre=eval_cb_pre,
eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre,
hessian_cb_pre=lambda *args: None,
hessian_cb_post=lambda *args: None):
self.local_reduced_functional = ReducedFunctional(
functional, control,
derivative_components=derivative_components,
scale=scale, tape=tape,
eval_cb_pre=eval_cb_pre,
eval_cb_post=eval_cb_post,
derivative_cb_pre=derivative_cb_pre,
derivative_cb_post=derivative_cb_post,
hessian_cb_pre=hessian_cb_pre, hessian_cb_post=hessian_cb_post)
hessian_cb_pre=hessian_cb_pre,
hessian_cb_post=hessian_cb_post
)

self.ensemble = ensemble
self.scatter_control = scatter_control
self.gather_functional = gather_functional

@property
def controls(self):
return self.local_reduced_functional.controls

def _allgather_J(self, J):
if isinstance(J, float):
vals = self.ensemble.ensemble_comm.allgather(J)
elif isinstance(J, firedrake.Function):
elif isinstance(J, Function):
# allgather not implemented in ensemble.py
vals = []
for i in range(self.ensemble.ensemble_comm.size):
Expand All @@ -134,30 +147,31 @@ def __call__(self, values):
The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`.

"""
local_functional = super(EnsembleReducedFunctional, self).__call__(values)
local_functional = self.local_reduced_functional(values)
ensemble_comm = self.ensemble.ensemble_comm
if self.gather_functional:
controls_g = self._allgather_J(local_functional)
total_functional = self.gather_functional(controls_g)
# if gather_functional is None then we do a sum
elif isinstance(local_functional, float):
total_functional = ensemble_comm.allreduce(sendobj=local_functional, op=MPI.SUM)
elif isinstance(local_functional, firedrake.Function):
elif isinstance(local_functional, Function):
total_functional = type(local_functional)(local_functional.function_space())
total_functional = self.ensemble.allreduce(local_functional, total_functional)
else:
raise NotImplementedError("This type of functional is not supported.")
return total_functional

def derivative(self, adj_input=1.0, options=None):
def derivative(self, adj_input=1.0, apply_riesz=False):
"""Compute derivatives of a functional with respect to the control parameters.

Parameters
----------
adj_input : float
The adjoint input.
options : dict
Additional options for the derivative computation.
apply_riesz: bool
If True, apply the Riesz map of each control in order to return
a primal gradient rather than a derivative in the dual space.

Returns
-------
Expand All @@ -171,29 +185,62 @@ def derivative(self, adj_input=1.0, options=None):

if self.gather_functional:
dJg_dmg = self.gather_functional.derivative(adj_input=adj_input,
options=options)
apply_riesz=False)
i = self.ensemble.ensemble_comm.rank
adj_input = dJg_dmg[i]

dJdm_local = super(EnsembleReducedFunctional, self).derivative(adj_input=adj_input, options=options)
dJdm_local = self.local_reduced_functional.derivative(adj_input=adj_input,
apply_riesz=apply_riesz)

if self.scatter_control:
dJdm_local = Enlist(dJdm_local)
dJdm_total = []

for dJdm in dJdm_local:
if not isinstance(dJdm, (firedrake.Function, float)):
raise NotImplementedError("This type of gradient is not supported.")
if not isinstance(dJdm, (Cofunction, Function, float)):
raise NotImplementedError(
f"Gradients of type {type(dJdm).__name__} are not supported.")

dJdm_total.append(
self.ensemble.allreduce(dJdm, type(dJdm)(dJdm.function_space()))
if isinstance(dJdm, firedrake.Function)
if isinstance(dJdm, (Cofunction, Function))
else self.ensemble.ensemble_comm.allreduce(sendobj=dJdm, op=MPI.SUM)
)
return dJdm_local.delist(dJdm_total)
return dJdm_local

def hessian(self, m_dot, options=None):
def tlm(self, m_dot):
"""Return the action of the tangent linear model of the functional.

The tangent linear model is evaluated w.r.t. the control on a vector
m_dot, around the last supplied value of the control.

Parameters
----------
m_dot : pyadjoint.OverloadedType
The direction in which to compute the action of the tangent linear model.

Returns
-------
pyadjoint.OverloadedType: The action of the tangent linear model in the
direction m_dot. Should be an instance of the same type as the functional.
"""
local_tlm = self.local_reduced_functional.tlm(m_dot)
ensemble_comm = self.ensemble.ensemble_comm
if self.gather_functional:
mdot_g = self._allgather_J(local_tlm)
total_tlm = self.gather_functional.tlm(mdot_g)
# if gather_functional is None then we do a sum
elif isinstance(local_tlm, float):
total_tlm = ensemble_comm.allreduce(sendobj=local_tlm, op=MPI.SUM)
elif isinstance(local_tlm, Function):
total_tlm = type(local_tlm)(local_tlm.function_space())
total_tlm = self.ensemble.allreduce(local_tlm, total_tlm)
else:
raise NotImplementedError("This type of functional is not supported.")
return total_tlm

def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
"""The Hessian is not yet implemented for ensemble reduced functional.

Raises:
Expand Down
20 changes: 16 additions & 4 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,15 @@ def _init_solver_parameters(self, args, kwargs):
self.assemble_kwargs = {}

def __str__(self):
return "solve({} = {})".format(ufl2unicode(self.lhs),
ufl2unicode(self.rhs))
try:
lhs_string = ufl2unicode(self.lhs)
except AttributeError:
lhs_string = str(self.lhs)
try:
rhs_string = ufl2unicode(self.rhs)
except AttributeError:
rhs_string = str(self.rhs)
return "solve({} = {})".format(lhs_string, rhs_string)

def _create_F_form(self):
# Process the equation forms, replacing values with checkpoints,
Expand Down Expand Up @@ -742,7 +749,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
c = block_variable.output
c_rep = block_variable.saved_output

if isinstance(c, firedrake.Function):
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.Constant):
mesh = F_form.ufl_domain()
Expand Down Expand Up @@ -779,7 +786,12 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
replace_map[self.func] = self.get_outputs()[0].saved_output
dFdm = replace(dFdm, replace_map)

dFdm = dFdm * adj_sol
if isinstance(dFdm, firedrake.Argument):
# Corner case. Should be fixed more permanently upstream in UFL.
# See: https://github.com/FEniCS/ufl/issues/395
dFdm = ufl.Action(dFdm, adj_sol)
else:
dFdm = dFdm * adj_sol
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)

return dFdm
Expand Down
12 changes: 2 additions & 10 deletions firedrake/adjoint_utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pyadjoint.adjfloat import AdjFloat
from pyadjoint.tape import get_working_tape, annotate_tape
from pyadjoint.overloaded_type import OverloadedType, create_overloaded_object
from pyadjoint.reduced_functional_numpy import gather

from firedrake.functionspace import FunctionSpace
from firedrake.adjoint_utils.blocks import ConstantAssignBlock
Expand Down Expand Up @@ -58,15 +57,8 @@ def wrapper(self, *args, **kwargs):

return wrapper

def get_derivative(self, options={}):
return self._ad_convert_type(self.adj_value, options=options)

def _ad_convert_type(self, value, options={}):
if value is None:
# TODO: Should the default be 0 constant here or return just None?
return type(self)(numpy.zeros(self.ufl_shape))
value = gather(value)
return self._constant_from_values(value)
def _ad_init_zero(self, dual=False):
return type(self)(numpy.zeros(self.ufl_shape))

def _ad_function_space(self, mesh):
element = self.ufl_element()
Expand Down
7 changes: 7 additions & 0 deletions firedrake/adjoint_utils/ensemble_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ def _ad_dot(self, other, options=None):
def _ad_convert_riesz(self, value, options=None):
raise NotImplementedError

def _ad_init_zero(self, dual=False):
from firedrake import EnsembleFunction, EnsembleCofunction
if dual:
return EnsembleCofunction(self.function_space().dual())
else:
return EnsembleFunction(self.function_space())

def _ad_create_checkpoint(self):
if disk_checkpointing():
raise NotImplementedError(
Expand Down
Loading
Loading