Skip to content

Commit c2f966e

Browse files
FilippoOlivodario-coscia
authored andcommitted
Improve doc condition
1 parent 240cbce commit c2f966e

File tree

5 files changed

+106
-69
lines changed

5 files changed

+106
-69
lines changed

pina/condition/condition.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
"""
2-
Condition module.
3-
"""
1+
"""Condition module."""
42

53
import warnings
64
from .data_condition import DataCondition
@@ -15,11 +13,12 @@
1513

1614

1715
def warning_function(new, old):
18-
"""
19-
Handle the deprecation warning.
16+
"""Handle the deprecation warning.
2017
21-
:param str new: Object to use instead of the old one.
22-
:param str old: Object to deprecate.
18+
:param new: Object to use instead of the old one.
19+
:type new: str
20+
:param old: Object to deprecate.
21+
:type old: str
2322
"""
2423
warnings.warn(
2524
f"'{old}' is deprecated and will be removed "
@@ -30,49 +29,58 @@ def warning_function(new, old):
3029

3130
class Condition:
3231
"""
33-
The class `Condition` is used to represent the constraints (physical
32+
The class ``Condition`` is used to represent the constraints (physical
3433
equations, boundary conditions, etc.) that should be satisfied in the
3534
problem at hand. Condition objects are used to formulate the
3635
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
3736
Conditions can be specified in four ways:
3837
39-
1. By specifying the input and output points of the condition; in such a
38+
1. By specifying the input and target of the condition; in such a
4039
case, the model is trained to produce the output points given the input
41-
points. Those points can either be torch.Tensor, LabelTensors, Graph
40+
points. Those points can either be torch.Tensor, LabelTensors, Graph.
41+
Based on the type of the input and target, there are different
42+
implementations of the condition. For more details, see
43+
:class:`~pina.condition.input_target_condition.InputTargetCondition`.
4244
43-
2. By specifying the location and the equation of the condition; in such
45+
2. By specifying the domain and the equation of the condition; in such
4446
a case, the model is trained to minimize the equation residual by
45-
evaluating it at some samples of the location.
47+
evaluating it at some samples of the domain.
4648
47-
3. By specifying the input points and the equation of the condition; in
49+
3. By specifying the input and the equation of the condition; in
4850
such a case, the model is trained to minimize the equation residual by
4951
evaluating it at the passed input points. The input points must be
50-
a LabelTensor.
52+
a LabelTensor. Based on the type of the input, there are different
53+
implementations of the condition. For more details, see
54+
:class:`~pina.condition.input_equation_condition.InputEquationCondition`
55+
.
5156
52-
4. By specifying only the data matrix; in such a case the model is
57+
4. By specifying only the input data; in such a case the model is
5358
trained with an unsupervised costum loss and uses the data in training.
5459
Additionaly conditioning variables can be passed, whenever the model
55-
has extra conditioning variable it depends on.
60+
has extra conditioning variable it depends on. Based on the type of the
61+
input, there are different implementations of the condition. For more
62+
details, see :class:`~pina.condition.data_condition.DataCondition`.
5663
5764
Example::
5865
59-
>>> from pina import Condition
60-
>>> condition = Condition(
61-
... input=input,
62-
... target=target
63-
... )
64-
>>> condition = Condition(
65-
... domain=location,
66-
... equation=equation
67-
... )
68-
>>> condition = Condition(
69-
... input=input,
70-
... equation=equation
71-
... )
72-
>>> condition = Condition(
73-
... input=data,
74-
... conditional_variables=conditional_variables
75-
... )
66+
>>> from pina import Condition
67+
>>> condition = Condition(
68+
... input=input,
69+
... target=target
70+
... )
71+
>>> condition = Condition(
72+
... domain=location,
73+
... equation=equation
74+
... )
75+
>>> condition = Condition(
76+
... input=input,
77+
... equation=equation
78+
... )
79+
>>> condition = Condition(
80+
... input=data,
81+
... conditional_variables=conditional_variables
82+
... )
83+
7684
"""
7785

7886
__slots__ = list(
@@ -86,24 +94,14 @@ class Condition:
8694

8795
def __new__(cls, *args, **kwargs):
8896
"""
89-
Create a new condition object based on the keyword arguments passed.
90-
91-
- `input` and `target`:
92-
:class:`~pina.condition.input_target_condition.InputTargetCondition`
93-
- `domain` and `equation`:
94-
:class:`~pina.condition.domain_equation_condition.
95-
DomainEquationCondition`
96-
- `input` and `equation`: :class:`~pina.condition.
97-
input_equation_condition.InputEquationCondition`
98-
- `input`: :class:`~pina.condition.data_condition.DataCondition`
99-
- `input` and `conditional_variables`:
100-
:class:`~pina.condition.data_condition.DataCondition`
101-
:return: A new condition instance belonging to the proper class.
102-
:rtype: InputTargetCondition | DomainEquationCondition |
103-
InputEquationCondition | DataCondition
104-
105-
:raises ValueError: No valid condition has been found.
97+
Check the input arguments and return the appropriate Condition object.
98+
99+
:raises ValueError: If no keyword arguments are passed.
100+
:raises ValueError: If the keyword arguments are invalid.
101+
:return: The appropriate Condition object.
102+
:rtype: ConditionInterface
106103
"""
104+
107105
if len(args) != 0:
108106
raise ValueError(
109107
"Condition takes only the following keyword "

pina/condition/condition_interface.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,61 @@
1111
class ConditionInterface(metaclass=ABCMeta):
1212
"""
1313
Abstract class which defines a common interface for all the conditions.
14+
It defined a common interface for all the conditions.
15+
1416
"""
1517

1618
def __init__(self):
19+
"""
20+
Initialize the ConditionInterface object.
21+
"""
22+
1723
self._problem = None
1824

1925
@property
2026
def problem(self):
2127
"""
2228
Return the problem to which the condition is associated.
2329
24-
:return: Problem to which the condition is associated.
30+
:return: Problem to which the condition is associated
2531
:rtype: pina.problem.AbstractProblem
2632
"""
27-
2833
return self._problem
2934

3035
@problem.setter
3136
def problem(self, value):
3237
"""
3338
Set the problem to which the condition is associated.
3439
35-
:param pina.problem.AbstractProblem value: Problem to which the
36-
condition is associated.
40+
:param pina.problem.abstract_problem.AbstractProblem value: Problem to
41+
which the condition is associated
3742
"""
38-
3943
self._problem = value
4044

4145
@staticmethod
4246
def _check_graph_list_consistency(data_list):
4347
"""
44-
Check if the list of :class:`~torch_geometric.data.Data` or
45-
class:`pina.graphGraph` objects is consistent.
46-
47-
:param data_list: List of graph type objects.
48-
:type data_list: Data | Graph | list[Data] | list[Graph]
49-
50-
:raises ValueError: Input data must be either Data
51-
or Graph objects.
52-
:raises ValueError: All elements in the list must have the same keys.
53-
:raises ValueError: Type mismatch in data tensors.
54-
:raises ValueError: Label mismatch in LabelTensors.
48+
Check the consistency of the list of Data/Graph objects. It performs
49+
the following checks:
50+
51+
1. All elements in the list must be of the same type (either Data or
52+
Graph).
53+
2. All elements in the list must have the same keys.
54+
3. The type of each tensor must be consistent across all elements in
55+
the list.
56+
4. If the tensor is a LabelTensor, the labels must be consistent across
57+
all elements in the list.
58+
59+
:param data_list: List of Data/Graph objects to check
60+
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
61+
62+
:raises ValueError: If the input types are invalid.
63+
:raises ValueError: If all elements in the list do not have the same
64+
keys.
65+
:raises ValueError: If the type of each tensor is not consistent across
66+
all elements in the list.
67+
:raises ValueError: If the labels of the LabelTensors are not consistent
68+
across all elements in the list.
5569
"""
5670

5771
# If the data is a Graph or Data object, return (do not need to check

pina/condition/data_condition.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
class DataCondition(ConditionInterface):
1313
"""
1414
Condition defined by input data and conditional variables. It can be used
15-
in unsupervised learning problems.
15+
in unsupervised learning problems. Based on the type of the input,
16+
different condition implementations are available:
17+
18+
- :class:`TensorDataCondition`: For :class:`torch.Tensor` or
19+
:class:`~pina.label_tensor.LabelTensor` input data.
20+
- :class:`GraphDataCondition`: For :class:`~pina.graph.Graph` or
21+
:class:`~torch_geometric.data.Data` input data.
1622
"""
1723

1824
__slots__ = ["input", "conditional_variables"]

pina/condition/input_equation_condition.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
class InputEquationCondition(ConditionInterface):
1414
"""
1515
Condition defined by input data and an equation. This condition can be
16-
used in a Physics Informed problems.
16+
used in a Physics Informed problems. Based on the type of the input,
17+
different condition implementations are available:
18+
19+
- :class:`InputTensorEquationCondition`: For
20+
:class:`~pina.label_tensor.LabelTensor` input data.
21+
- :class:`InputGraphEquationCondition`: For :class:`~pina.graph.Graph`
22+
input data.
1723
"""
1824

1925
__slots__ = ["input", "equation"]

pina/condition/input_target_condition.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,20 @@
1212
class InputTargetCondition(ConditionInterface):
1313
"""
1414
Condition defined by input and target data. This condition can be used in
15-
both supervised learning and Physics-informed problems.
15+
both supervised learning and Physics-informed problems. Based on the type of
16+
the input and target, different condition implementations are available:
17+
18+
- :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or
19+
:class:`~pina.label_tensor.LabelTensor` input and target data.
20+
- :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or
21+
:class:`~pina.label_tensor.LabelTensor` input and
22+
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
23+
target data.
24+
- :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph`
25+
or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor`
26+
or :class:`~pina.label_tensor.LabelTensor` target data.
27+
- :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` or
28+
:class:`~torch_geometric.data.Data` input and target data.
1629
"""
1730

1831
__slots__ = ["input", "target"]

0 commit comments

Comments
 (0)