Skip to content

Commit a0015c3

Browse files
add exhaustive doc for condition module (#629)
1 parent f3ccfd4 commit a0015c3

File tree

6 files changed

+370
-250
lines changed

6 files changed

+370
-250
lines changed

pina/condition/condition.py

Lines changed: 80 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,91 @@
11
"""Module for the Condition class."""
22

3-
import warnings
43
from .data_condition import DataCondition
54
from .domain_equation_condition import DomainEquationCondition
65
from .input_equation_condition import InputEquationCondition
76
from .input_target_condition import InputTargetCondition
8-
from ..utils import custom_warning_format
97

10-
# Set the custom format for warnings
11-
warnings.formatwarning = custom_warning_format
12-
warnings.filterwarnings("always", category=DeprecationWarning)
138

9+
class Condition:
10+
"""
11+
The :class:`Condition` class is a core component of the PINA framework that
12+
provides a unified interface to define heterogeneous constraints that must
13+
be satisfied by a :class:`~pina.problem.abstract_problem.AbstractProblem`.
1414
15-
def warning_function(new, old):
16-
"""Handle the deprecation warning.
15+
It encapsulates all types of constraints - physical, boundary, initial, or
16+
data-driven - that the solver must satisfy during training. The specific
17+
behavior is inferred from the arguments passed to the constructor.
1718
18-
:param new: Object to use instead of the old one.
19-
:type new: str
20-
:param old: Object to deprecate.
21-
:type old: str
22-
"""
23-
warnings.warn(
24-
f"'{old}' is deprecated and will be removed "
25-
f"in future versions. Please use '{new}' instead.",
26-
DeprecationWarning,
27-
)
19+
Multiple types of conditions can be used within the same problem, allowing
20+
for a high degree of flexibility in defining complex problems.
2821
22+
The :class:`Condition` class behavior specializes internally based on the
23+
arguments provided during instantiation. Depending on the specified keyword
24+
arguments, the class automatically selects the appropriate internal
25+
implementation.
2926
30-
class Condition:
31-
"""
32-
Represents constraints (such as physical equations, boundary conditions,
33-
etc.) that must be satisfied in a given problem. Condition objects are used
34-
to formulate the PINA
35-
:class:`~pina.problem.abstract_problem.AbstractProblem` object.
3627
37-
There are different types of conditions:
28+
Available `Condition` types:
3829
3930
- :class:`~pina.condition.input_target_condition.InputTargetCondition`:
40-
Defined by specifying both the input and the target of the condition. In
41-
this case, the model is trained to produce the target given the input. The
42-
input and output data must be one of the :class:`torch.Tensor`,
43-
:class:`~pina.label_tensor.LabelTensor`,
44-
:class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`.
45-
Different implementations exist depending on the type of input and target.
46-
For more details, see
47-
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
31+
represents a supervised condition defined by both ``input`` and ``target``
32+
data. The model is trained to reproduce the ``target`` values given the
33+
``input``. Supported data types include :class:`torch.Tensor`,
34+
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
35+
:class:`~torch_geometric.data.Data`.
36+
The class automatically selects the appropriate implementation based on
37+
the types of ``input`` and ``target``.
4838
4939
- :class:`~pina.condition.domain_equation_condition.DomainEquationCondition`
50-
: Defined by specifying both the domain and the equation of the condition.
51-
Here, the model is trained to minimize the equation residual by evaluating
52-
it at sampled points within the domain.
40+
: represents a general physics-informed condition defined by a ``domain``
41+
and an ``equation``. The model learns to minimize the equation residual
42+
through evaluations performed at points sampled from the specified domain.
5343
5444
- :class:`~pina.condition.input_equation_condition.InputEquationCondition`:
55-
Defined by specifying the input and the equation of the condition. In this
56-
case, the model is trained to minimize the equation residual by evaluating
57-
it at the provided input. The input must be either a
58-
:class:`~pina.label_tensor.LabelTensor` or a :class:`~pina.graph.Graph`.
59-
Different implementations exist depending on the type of input. For more
60-
details, see
61-
:class:`~pina.condition.input_equation_condition.InputEquationCondition`.
62-
63-
- :class:`~pina.condition.data_condition.DataCondition`:
64-
Defined by specifying only the input. In this case, the model is trained
65-
with an unsupervised custom loss while using the provided data during
66-
training. The input data must be one of :class:`torch.Tensor`,
67-
:class:`~pina.label_tensor.LabelTensor`,
68-
:class:`~torch_geometric.data.Data`, or :class:`~pina.graph.Graph`.
69-
Additionally, conditional variables can be provided when the model
70-
depends on extra parameters. These conditional variables must be either
71-
:class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`.
72-
Different implementations exist depending on the type of input.
73-
For more details, see
74-
:class:`~pina.condition.data_condition.DataCondition`.
45+
represents a general physics-informed condition defined by ``input``
46+
points and an ``equation``. The model learns to minimize the equation
47+
residual through evaluations performed at the provided ``input``.
48+
Supported data types for the ``input`` include
49+
:class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`.
50+
The class automatically selects the appropriate implementation based on
51+
the types of the ``input``.
52+
53+
- :class:`~pina.condition.data_condition.DataCondition`: represents an
54+
unsupervised, data-driven condition defined by the ``input`` only.
55+
The model is trained using a custom unsupervised loss determined by the
56+
chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the
57+
provided data during training. Optional ``conditional_variables`` can be
58+
specified when the model depends on additional parameters.
59+
Supported data types include :class:`torch.Tensor`,
60+
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
61+
:class:`~torch_geometric.data.Data`.
62+
The class automatically selects the appropriate implementation based on
63+
the type of the ``input``.
64+
65+
.. note::
66+
67+
The user should always instantiate :class:`Condition` directly, without
68+
manually creating subclass instances. Please refer to the specific
69+
:class:`Condition` classes for implementation details.
7570
7671
:Example:
7772
7873
>>> from pina import Condition
79-
>>> condition = Condition(
80-
... input=input,
81-
... target=target
82-
... )
83-
>>> condition = Condition(
84-
... domain=location,
85-
... equation=equation
86-
... )
87-
>>> condition = Condition(
88-
... input=input,
89-
... equation=equation
90-
... )
91-
>>> condition = Condition(
92-
... input=data,
93-
... conditional_variables=conditional_variables
94-
... )
9574
75+
>>> # Example of InputTargetCondition signature
76+
>>> condition = Condition(input=input, target=target)
77+
78+
>>> # Example of DomainEquationCondition signature
79+
>>> condition = Condition(domain=domain, equation=equation)
80+
81+
>>> # Example of InputEquationCondition signature
82+
>>> condition = Condition(input=input, equation=equation)
83+
84+
>>> # Example of DataCondition signature
85+
>>> condition = Condition(input=data, conditional_variables=cond_vars)
9686
"""
9787

88+
# Combine all possible keyword arguments from the different Condition types
9889
__slots__ = list(
9990
set(
10091
InputTargetCondition.__slots__
@@ -106,46 +97,45 @@ class Condition:
10697

10798
def __new__(cls, *args, **kwargs):
10899
"""
109-
Instantiate the appropriate Condition object based on the keyword
110-
arguments passed.
100+
Instantiate the appropriate :class:`Condition` object based on the
101+
keyword arguments passed.
111102
112-
:raises ValueError: If no keyword arguments are passed.
103+
:param tuple args: The positional arguments (should be empty).
104+
:param dict kwargs: The keyword arguments corresponding to the
105+
parameters of the specific :class:`Condition` type to instantiate.
106+
:raises ValueError: If unexpected positional arguments are provided.
113107
:raises ValueError: If the keyword arguments are invalid.
114-
:return: The appropriate Condition object.
108+
:return: The appropriate :class:`Condition` object.
115109
:rtype: ConditionInterface
116110
"""
117-
111+
# Check keyword arguments
118112
if len(args) != 0:
119113
raise ValueError(
120114
"Condition takes only the following keyword "
121115
f"arguments: {Condition.__slots__}."
122116
)
123117

124-
# back-compatibility 0.1
125-
keys = list(kwargs.keys())
126-
if "location" in keys:
127-
kwargs["domain"] = kwargs.pop("location")
128-
warning_function(new="domain", old="location")
129-
130-
if "input_points" in keys:
131-
kwargs["input"] = kwargs.pop("input_points")
132-
warning_function(new="input", old="input_points")
133-
134-
if "output_points" in keys:
135-
kwargs["target"] = kwargs.pop("output_points")
136-
warning_function(new="target", old="output_points")
137-
118+
# Class specialization based on keyword arguments
138119
sorted_keys = sorted(kwargs.keys())
120+
121+
# Input - Target Condition
139122
if sorted_keys == sorted(InputTargetCondition.__slots__):
140123
return InputTargetCondition(**kwargs)
124+
125+
# Input - Equation Condition
141126
if sorted_keys == sorted(InputEquationCondition.__slots__):
142127
return InputEquationCondition(**kwargs)
128+
129+
# Domain - Equation Condition
143130
if sorted_keys == sorted(DomainEquationCondition.__slots__):
144131
return DomainEquationCondition(**kwargs)
132+
133+
# Data Condition
145134
if (
146135
sorted_keys == sorted(DataCondition.__slots__)
147136
or sorted_keys[0] == DataCondition.__slots__[0]
148137
):
149138
return DataCondition(**kwargs)
150139

140+
# Invalid keyword arguments
151141
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

pina/condition/condition_interface.py

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,115 +8,118 @@
88

99
class ConditionInterface(metaclass=ABCMeta):
1010
"""
11-
Abstract class which defines a common interface for all the conditions.
12-
It defined a common interface for all the conditions.
11+
Abstract base class for PINA conditions. All specific conditions must
12+
inherit from this interface.
1313
14+
Refer to :class:`pina.condition.condition.Condition` for a thorough
15+
description of all available conditions and how to instantiate them.
1416
"""
1517

1618
def __init__(self):
1719
"""
18-
Initialize the ConditionInterface object.
20+
Initialization of the :class:`ConditionInterface` class.
1921
"""
20-
2122
self._problem = None
2223

2324
@property
2425
def problem(self):
2526
"""
26-
Return the problem to which the condition is associated.
27+
Return the problem associated with this condition.
2728
28-
:return: Problem to which the condition is associated.
29+
:return: Problem associated with this condition.
2930
:rtype: ~pina.problem.abstract_problem.AbstractProblem
3031
"""
3132
return self._problem
3233

3334
@problem.setter
3435
def problem(self, value):
3536
"""
36-
Set the problem to which the condition is associated.
37+
Set the problem associated with this condition.
3738
38-
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
39-
which the condition is associated
39+
:param pina.problem.abstract_problem.AbstractProblem value: The problem
40+
to associate with this condition
4041
"""
4142
self._problem = value
4243

4344
@staticmethod
4445
def _check_graph_list_consistency(data_list):
4546
"""
46-
Check the consistency of the list of Data/Graph objects. It performs
47-
the following checks:
48-
49-
1. All elements in the list must be of the same type (either Data or
50-
Graph).
51-
2. All elements in the list must have the same keys.
52-
3. The type of each tensor must be consistent across all elements in
53-
the list.
54-
4. If the tensor is a LabelTensor, the labels must be consistent across
55-
all elements in the list.
56-
57-
:param data_list: List of Data/Graph objects to check
58-
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
47+
Check the consistency of the list of Data | Graph objects.
48+
The following checks are performed:
49+
50+
- All elements in the list must be of the same type (either
51+
:class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
52+
53+
- All elements in the list must have the same keys.
54+
55+
- The data type of each tensor must be consistent across all elements.
5956
60-
:raises ValueError: If the input types are invalid.
57+
- If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
58+
must also be consistent across all elements.
59+
60+
:param data_list: The list of Data | Graph objects to check.
61+
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
62+
:raises ValueError: If the input types are invalid.
6163
:raises ValueError: If all elements in the list do not have the same
6264
keys.
6365
:raises ValueError: If the type of each tensor is not consistent across
6466
all elements in the list.
6567
:raises ValueError: If the labels of the LabelTensors are not consistent
6668
across all elements in the list.
6769
"""
68-
69-
# If the data is a Graph or Data object, return (do not need to check
70-
# anything)
70+
# If the data is a Graph or Data object, perform no checks
7171
if isinstance(data_list, (Graph, Data)):
7272
return
7373

74-
# check all elements in the list are of the same type
74+
# Check all elements in the list are of the same type
7575
if not all(isinstance(i, (Graph, Data)) for i in data_list):
7676
raise ValueError(
77-
"Invalid input types. "
78-
"Please provide either Data or Graph objects."
77+
"Invalid input. Please, provide either Data or Graph objects."
7978
)
79+
80+
# Store the keys, data types and labels of the first element
8081
data = data_list[0]
81-
# Store the keys of the first element in the list
8282
keys = sorted(list(data.keys()))
83-
84-
# Store the type of each tensor inside first element Data/Graph object
8583
data_types = {name: tensor.__class__ for name, tensor in data.items()}
86-
87-
# Store the labels of each LabelTensor inside first element Data/Graph
88-
# object
8984
labels = {
9085
name: tensor.labels
9186
for name, tensor in data.items()
9287
if isinstance(tensor, LabelTensor)
9388
}
9489

95-
# Iterate over the list of Data/Graph objects
90+
# Iterate over the list of Data | Graph objects
9691
for data in data_list[1:]:
97-
# Check if the keys of the current element are the same as the first
98-
# element
92+
93+
# Check that all elements in the list have the same keys
9994
if sorted(list(data.keys())) != keys:
10095
raise ValueError(
10196
"All elements in the list must have the same keys."
10297
)
98+
99+
# Iterate over the tensors in the current element
103100
for name, tensor in data.items():
104-
# Check if the type of each tensor inside the current element
105-
# is the same as the first element
101+
# Check that the type of each tensor is consistent
106102
if tensor.__class__ is not data_types[name]:
107103
raise ValueError(
108104
f"Data {name} must be a {data_types[name]}, got "
109105
f"{tensor.__class__}"
110106
)
111-
# If the tensor is a LabelTensor, check if the labels are the
112-
# same as the first element
107+
108+
# Check that the labels of each LabelTensor are consistent
113109
if isinstance(tensor, LabelTensor):
114110
if tensor.labels != labels[name]:
115111
raise ValueError(
116112
"LabelTensor must have the same labels"
117113
)
118114

119115
def __getattribute__(self, name):
116+
"""
117+
Get an attribute from the object.
118+
119+
:param str name: The name of the attribute to get.
120+
:return: The requested attribute.
121+
:rtype: Any
122+
"""
120123
to_return = super().__getattribute__(name)
121124
if isinstance(to_return, (Graph, Data)):
122125
to_return = [to_return]

0 commit comments

Comments
 (0)