You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
24
-
Tuple containing the arguments of the linear form associated with the ML operator,
25
-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
26
-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
27
-
as a result of taking the action on a given function.
24
+
Tuple containing the arguments of the linear form associated with the ML operator,
25
+
i.e. the arguments with respect to which the ML operator is linear. Those arguments can
26
+
be ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
27
+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` object, as a result
28
+
of taking the action on a given function. If argument slots are not provided, then they will
29
+
be generated in the :class:`.AbstractExternalOperator` constructor.
28
30
operator_data : dict
29
-
Dictionary to stash external data specific to the ML operator. This dictionary must
30
-
at least contain the following:
31
-
(i) 'model': The machine learning model implemented in the ML framework considered.
32
-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
33
-
Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
31
+
Dictionary to stash external data specific to the ML operator. This dictionary must
32
+
at least contain the following:
33
+
(i) 'model': The machine learning model implemented in the ML framework considered.
34
+
(ii) 'inputs_format': The format of the inputs to the ML model: ``0`` for models acting globally
35
+
on the inputs, ``1`` when acting locally/pointwise on the inputs.
36
+
Other strategies can also be considered by subclassing the :class:`.MLOperator` class.
"""External operator class representing machine learning models implemented in JAX.
45
+
"""
46
+
External operator class representing machine learning models implemented in JAX.
46
47
47
48
The :class:`.JaxOperator` allows users to embed machine learning models implemented in JAX
48
-
into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator` is
49
-
delegated to the specified JAX model. Similarly, differentiation through the :class:`.JaxOperator`
50
-
class is achieved using JAX differentiation on the JAX model associated with the :class:`.JaxOperator` object.
49
+
into PDE systems implemented in Firedrake. The actual evaluation of the :class:`.JaxOperator`
50
+
is delegated to the specified JAX model. Similarly, differentiation through the
51
+
:class:`.JaxOperator` is achieved using JAX differentiation on the associated JAX model.
51
52
52
53
Parameters
53
54
----------
54
55
*operands
55
-
Operands of the :class:`.JaxOperator`.
56
+
Operands of the :class:`.JaxOperator`.
56
57
function_space
57
-
The function space the ML operator is mapping to.
58
+
The function space the ML operator is mapping to.
58
59
derivatives
59
-
Tuple specifiying the derivative multiindex.
60
-
*argument_slots
61
-
Tuple containing the arguments of the linear form associated with the ML operator,
62
-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
63
-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
64
-
as a result of taking the action on a given function.
60
+
Tuple specifying the derivative multi-index.
61
+
argument_slots
62
+
Tuple containing the arguments of the linear form associated with the ML operator,
63
+
i.e., the arguments with respect to which the ML operator is linear. These arguments
64
+
can be ``ufl.argument.BaseArgument`` objects, as a result of differentiation,
65
+
or both ``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` objects,
66
+
as a result of taking the action on a given function.
65
67
operator_data
66
-
Dictionary to stash external data specific to the ML operator. This dictionary must
67
-
at least contain the following:
68
-
(i) 'model': The machine learning model implemented in JaX
69
-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
70
-
Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
68
+
Dictionary to stash external data specific to the ML operator. This dictionary must
69
+
contain the following:
70
+
(i) ``'model'`` : The machine learning model implemented in JaX.
71
+
(ii) ``'model'`` : The format of the inputs to the ML model: ``0`` for models acting
72
+
globally on the inputs. ``1`` for models acting locally/pointwise on the inputs.
73
+
Other strategies can also be considered by subclassing the :class:`.JaxOperator` class.
*argument_slots : ufl.coefficient.BaseCoefficient or ufl.argument.BaseArgument
49
-
Tuple containing the arguments of the linear form associated with the ML operator,
50
-
i.e. the arguments with respect to which the ML operator is linear. Those arguments
51
-
can be ufl.Argument objects, as a result of differentiation, or ufl.Coefficient objects,
52
-
as a result of taking the action on a given function.
49
+
Tuple containing the arguments of the linear form associated with the ML operator, i.e. the
50
+
arguments with respect to which the ML operator is linear. Those arguments can be
51
+
``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
52
+
``ufl.coefficient.BaseCoefficient`` and ``ufl.argument.BaseArgument`` object, as a result
53
+
of taking the action on a given function.
53
54
operator_data : dict
54
-
Dictionary to stash external data specific to the ML operator. This dictionary must
55
-
at least contain the following:
56
-
(i) 'model': The machine learning model implemented in PyTorch.
57
-
(ii) 'inputs_format': The format of the inputs to the ML model: `0` for models acting globally on the inputs, `1` when acting locally/pointwise on the inputs.
58
-
Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
55
+
Dictionary to stash external data specific to the ML operator. This dictionary must
56
+
at least contain the following:
57
+
(i) ``'model'``: The machine learning model implemented in PyTorch.
58
+
(ii) ``'inputs_format'``: The format of the inputs to the ML model: ``0`` for models acting globally
59
+
on the inputs, ``1`` when acting locally/pointwise on the inputs.
60
+
Other strategies can also be considered by subclassing the :class:`.PytorchOperator` class.
0 commit comments