Skip to content

Commit 7ea3dfc

Browse files
committed
Merge branch 'master' into pbrubeck/fix/base-form-tensor
2 parents 5916217 + a84c071 commit 7ea3dfc

File tree

15 files changed

+141
-81
lines changed

15 files changed

+141
-81
lines changed

docs/source/images/sin_mesh.png

86.6 KB
Loading

docs/source/mesh-coordinates.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,30 @@ Or simply:
8383
8484
new_mesh = Mesh(Function(f))
8585
86+
Immersing a mesh in higher dimensional space
87+
--------------------------------------------
88+
89+
A useful special case of creating a mesh on modified coordinates is to immerse
90+
a lower dimensional mesh in a higher dimension, for example to create a mesh of
91+
a two-dimensional manifold immersed in 3D.
92+
93+
This is accomplished by setting the value dimension of the new
94+
:py:func:`~.VectorFunctionSpace` to that of the space in which it should be
95+
immersed. For example, a mesh of square bent into a sine wave using
96+
linear (flat) elements can be created with:
97+
98+
.. literalinclude:: ../../tests/firedrake/regression/test_mesh_generation.py
99+
:language: python3
100+
:dedent:
101+
:start-after: start_immerse
102+
:end-before: end_immerse
103+
104+
105+
.. figure:: images/sin_mesh.png
106+
:align: center
107+
108+
A sine-wave shaped triangle mesh immersed in three-dimensional space.
109+
86110

87111
Replacing the mesh geometry of an existing function
88112
---------------------------------------------------

firedrake/external_operators/abstract_external_operators.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,22 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
4747
Parameters
4848
----------
4949
*operands : ufl.core.expr.Expr or ufl.form.BaseForm
50-
Operands of the external operator.
50+
Operands of the external operator.
5151
function_space : firedrake.functionspaceimpl.WithGeometryBase
52-
The function space the external operator is mapping to.
52+
The function space the external operator is mapping to.
5353
derivatives : tuple
54-
Tuple specifiying the derivative multiindex.
54+
Tuple specifiying the derivative multiindex.
5555
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
56-
Tuple containing the arguments of the linear form associated with the external operator,
57-
i.e. the arguments with respect to which the external operator is linear. Those arguments
58-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
59-
as a result of taking the action on a given function.
56+
Tuple containing the arguments of the linear form associated with the external operator,
57+
i.e. the arguments with respect to which the external operator is linear. Those arguments can
58+
be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
59+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` object, as a result
60+
of taking the action on a given function.
6061
operator_data : dict
61-
Dictionary containing the data of the external operator, i.e. the external data
62-
specific to the external operator subclass considered. This dictionary will be passed on
63-
over the UFL symbolic reconstructions making the operator data accessible to the external operators
64-
arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
62+
Dictionary containing the data of the external operator, i.e. the external data
63+
specific to the external operator subclass considered. This dictionary will be passed on
64+
over the UFL symbolic reconstructions making the operator data accessible to the external operators
65+
arising from symbolic operations on the original operator, such as the Jacobian of the external operator.
6566
"""
6667
from firedrake_citations import Citations
6768
Citations().register("Bouziani2021")

firedrake/external_operators/ml_operator.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,25 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
1515
Parameters
1616
----------
1717
*operands : ufl.core.expr.Expr or ufl.form.BaseForm
18-
Operands of the ML operator.
18+
Operands of the ML operator.
1919
function_space : firedrake.functionspaceimpl.WithGeometryBase
20-
The function space the ML operator is mapping to.
20+
The function space the ML operator is mapping to.
2121
derivatives : tuple
22-
Tuple specifiying the derivative multiindex.
22+
Tuple specifiying the derivative multiindex.
2323
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
24-
Tuple containing the arguments of the linear form associated with the ML operator,
25-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
26-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
27-
as a result of taking the action on a given function.
24+
Tuple containing the arguments of the linear form associated with the ML operator,
25+
i.e. the arguments with respect to which the ML operator is linear. Those arguments can
26+
be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
27+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` object, as a result
28+
of taking the action on a given function. If argument slots are not provided, then they will
29+
be generated in the :class:`.AbstractExternalOperator` constructor.
2830
operator_data : dict
29-
Dictionary to stash external data specific to the ML operator. This dictionary must
30-
at least contain the following:
31-
(i) 'model': The machine learning model implemented in the ML framework considered.
32-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
33-
Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
31+
Dictionary to stash external data specific to the ML operator. This dictionary must
32+
at least contain the following:
33+
(i) 'model': The machine learning model implemented in the ML framework considered.
34+
(ii) 'inputs_format': The format of the inputs to the ML model: ``0`` for models acting globally
35+
on the inputs, ``1`` when acting locally/pointwise on the inputs.
36+
Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
3437
"""
3538
AbstractExternalOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
3639
argument_slots=argument_slots, operator_data=operator_data)

firedrake/ml/jax/fem_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def bwd(self, _, grad_output: "jax.Array") -> "jax.Array":
9292
adj_input = float(adj_input)
9393

9494
# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
95-
adj_output = self.F.derivative(adj_input=adj_input)
95+
adj_output = self.F.derivative(adj_input=adj_input, options={'riesz_representation': None})
9696

9797
# Tuplify adjoint output
9898
adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output

firedrake/ml/jax/ml_operator.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,36 +39,40 @@ def __init__(
3939
*operands: Union[ufl.core.expr.Expr, ufl.form.BaseForm],
4040
function_space: WithGeometryBase,
4141
derivatives: Optional[tuple] = None,
42-
argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]],
42+
argument_slots: tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]] = (),
4343
operator_data: Optional[dict] = {}
4444
):
45-
"""External operator class representing machine learning models implemented in JAX.
45+
"""
46+
External operator class representing machine learning models implemented in JAX.
4647
4748
The :class:`.JaxOperator` allows users to embed machine learning models implemented in JAX
48-
into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator` is
49-
delegated to the specified JAX model. Similarly, differentiation through the :class:`.JaxOperator`
50-
class is achieved using JAX differentiation on the JAX model associated with the :class:`.JaxOperator` object.
49+
into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator`
50+
is delegated to the specified JAX model. Similarly, differentiation through the
51+
:class:`.JaxOperator` is achieved using JAX differentiation on the associated JAX model.
5152
5253
Parameters
5354
----------
5455
*operands
55-
Operands of the :class:`.JaxOperator`.
56+
Operands of the :class:`.JaxOperator`.
5657
function_space
57-
The function space the ML operator is mapping to.
58+
The function space the ML operator is mapping to.
5859
derivatives
59-
Tuple specifiying the derivative multiindex.
60-
*argument_slots
61-
Tuple containing the arguments of the linear form associated with the ML operator,
62-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
63-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
64-
as a result of taking the action on a given function.
60+
Tuple specifying the derivative multi-index.
61+
argument_slots
62+
Tuple containing the arguments of the linear form associated with the ML operator,
63+
i.e., the arguments with respect to which the ML operator is linear. These arguments
64+
can be ``ufl.argument.BaseArgument`` objects, as a result of differentiation,
65+
or both ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects,
66+
as a result of taking the action on a given function.
6567
operator_data
66-
Dictionary to stash external data specific to the ML operator. This dictionary must
67-
at least contain the following:
68-
(i) 'model': The machine learning model implemented in JaX
69-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
70-
Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
68+
Dictionary to stash external data specific to the ML operator. This dictionary must
69+
contain the following:
70+
(i) ``'model'`` : The machine learning model implemented in JaX.
71+
(ii) ``'model'`` : The format of the inputs to the ML model: ``0`` for models acting
72+
globally on the inputs. ``1`` for models acting locally/pointwise on the inputs.
73+
Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
7174
"""
75+
7276
MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
7377
argument_slots=argument_slots, operator_data=operator_data)
7478

@@ -90,8 +94,7 @@ def _pre_forward_callback(self, *operands: Union[Function, Cofunction], unsqueez
9094

9195
def _post_forward_callback(self, y_P: "jax.Array") -> Union[Function, Cofunction]:
9296
"""Callback function to convert the JAX output of the ML model to a Firedrake function."""
93-
space = self.ufl_function_space()
94-
return from_jax(y_P, space)
97+
return from_jax(y_P, self.ufl_function_space())
9598

9699
# -- JAX routines for computing AD-based quantities -- #
97100

firedrake/ml/pytorch/fem_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def backward(ctx, grad_output):
8383
adj_input = float(adj_input)
8484

8585
# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
86-
adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"})
86+
adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": None})
8787

8888
# Tuplify adjoint output
8989
adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output

firedrake/ml/pytorch/ml_operator.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,24 @@ class is achieved via the `torch.autograd` module, which provides automatic diff
4040
Parameters
4141
----------
4242
*operands : ufl.core.expr.Expr or ufl.form.BaseForm
43-
Operands of the :class:`.PytorchOperator`.
43+
Operands of the :class:`.PytorchOperator`.
4444
function_space : firedrake.functionspaceimpl.WithGeometryBase
45-
The function space the ML operator is mapping to.
45+
The function space the ML operator is mapping to.
4646
derivatives : tuple
47-
Tuple specifiying the derivative multiindex.
47+
Tuple specifiying the derivative multiindex.
4848
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
49-
Tuple containing the arguments of the linear form associated with the ML operator,
50-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
51-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
52-
as a result of taking the action on a given function.
49+
Tuple containing the arguments of the linear form associated with the ML operator, i.e. the
50+
arguments with respect to which the ML operator is linear. Those arguments can be
51+
``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
52+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` object, as a result
53+
of taking the action on a given function.
5354
operator_data : dict
54-
Dictionary to stash external data specific to the ML operator. This dictionary must
55-
at least contain the following:
56-
(i) 'model': The machine learning model implemented in PyTorch.
57-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
58-
Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
55+
Dictionary to stash external data specific to the ML operator. This dictionary must
56+
at least contain the following:
57+
(i) ``'model'``: The machine learning model implemented in PyTorch.
58+
(ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting globally
59+
on the inputs, ``1`` when acting locally/pointwise on the inputs.
60+
Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
5961
"""
6062
MLOperator.__init__(self, *operands, function_space=function_space, derivatives=derivatives,
6163
argument_slots=argument_slots, operator_data=operator_data)
@@ -98,8 +100,7 @@ def _pre_forward_callback(self, *operands, unsqueeze=False):
98100

99101
def _post_forward_callback(self, y_P):
100102
"""Callback function to convert the PyTorch output of the ML model to a Firedrake function."""
101-
space = self.ufl_function_space()
102-
return from_torch(y_P, space)
103+
return from_torch(y_P, self.ufl_function_space())
103104

104105
# -- PyTorch routines for computing AD based quantities via `torch.autograd.functional` -- #
105106

pyop2/compilation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD):
155155
# Find the name of the compiler family
156156
if output.startswith("gcc") or output.startswith("g++"):
157157
name = "GNU"
158-
elif output.startswith("clang"):
158+
elif output.startswith("clang") or output.startswith("Homebrew clang"):
159159
name = "clang"
160160
elif output.startswith("Apple LLVM") or output.startswith("Apple clang"):
161161
name = "clang"

scripts/firedrake-install

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ def build_and_install_jax():
13031303
"""Install JAX for a CPU or CUDA backend."""
13041304
log.info("Installing JAX (backend: %s)" % args.jax)
13051305
version_name = "jax" if args.jax == "cpu" else "jax[cuda12]"
1306-
run_pip_install([version_name])
1306+
run_pip_install([version_name] + ["jaxlib"] + ["ml_dtypes"] + ["opt_einsum"])
13071307

13081308

13091309
def build_and_install_slepc():

0 commit comments

Comments
 (0)