Skip to content

Commit d6d778d

Browse files
FilippoOlivodario-coscia
authored andcommitted
Additional fix in condition
1 parent 0181851 commit d6d778d

File tree

5 files changed

+34
-40
lines changed

5 files changed

+34
-40
lines changed

pina/condition/condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ def __new__(cls, *args, **kwargs):
9898
- `input`: :class:`~pina.condition.data_condition.DataCondition`
9999
- `input` and `conditional_variables`:
100100
:class:`~pina.condition.data_condition.DataCondition`
101-
102-
:raises ValueError: No valid condition has been found.
103101
:return: A new condition instance belonging to the proper class.
104102
:rtype: InputTargetCondition | DomainEquationCondition |
105103
InputEquationCondition | DataCondition
104+
105+
:raises ValueError: No valid condition has been found.
106106
"""
107107
if len(args) != 0:
108108
raise ValueError(

pina/condition/data_condition.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111

1212
class DataCondition(ConditionInterface):
1313
"""
14-
This condition must be used every time a Unsupervised Loss is needed in
15-
the Solver. The `conditional_variable` can be passed as extra-input when
16-
the model learns a conditional distribution.
14+
Condition defined by input data and conditional variables. It can be used
15+
in unsupervised learning problems.
1716
"""
1817

1918
__slots__ = ["input", "conditional_variables"]
@@ -22,24 +21,23 @@ class DataCondition(ConditionInterface):
2221

2322
def __new__(cls, input, conditional_variables=None):
2423
"""
25-
Instantiate the appropriate subclass of DataCondition based on the
26-
types of input data.
24+
Instantiate the appropriate subclass of :class:`DataCondition` based on
25+
the type of `input`.
2726
2827
:param input: Input data for the condition.
2928
:type input: torch.Tensor | LabelTensor | Graph |
3029
Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
3130
:param conditional_variables: Conditional variables for the condition.
32-
:type conditional_variables: torch.Tensor | LabelTensor
31+
:type conditional_variables: torch.Tensor | LabelTensor, optional
3332
:return: Subclass of DataCondition.
3433
:rtype: pina.condition.data_condition.TensorDataCondition |
3534
pina.condition.data_condition.GraphDataCondition
3635
3736
:raises ValueError: If input is not of type :class:`torch.Tensor`,
3837
:class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`,
3938
or :class:`~torch_geometric.data.Data`.
40-
41-
4239
"""
40+
4341
if cls != DataCondition:
4442
return super().__new__(cls)
4543
if isinstance(input, (torch.Tensor, LabelTensor)):
@@ -69,8 +67,8 @@ def __init__(self, input, conditional_variables=None):
6967
7068
.. note::
7169
If either `input` is composed by a list of
72-
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
73-
objects, all elements must have the same structure (keys and data
70+
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`,
71+
all elements must have the same structure (keys and data
7472
types)
7573
"""
7674

pina/condition/domain_equation_condition.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010

1111
class DomainEquationCondition(ConditionInterface):
1212
"""
13-
Condition for domain/equation data. This condition must be used every
14-
time a Physics Informed Loss is needed in the Solver.
13+
Condition defined by a domain and an equation. It can be used in Physics
14+
Informed problems. Before using this condition, make sure that input data
15+
are correctly sampled from the domain.
1516
"""
1617

1718
__slots__ = ["domain", "equation"]
1819

1920
def __init__(self, domain, equation):
2021
"""
21-
Initialize the object by storing the domain and equation.
22+
Initialise the object by storing the domain and equation.
2223
2324
:param DomainInterface domain: Domain object containing the domain data.
2425
:param EquationInterface equation: Equation object containing the

pina/condition/input_equation_condition.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
class InputEquationCondition(ConditionInterface):
1414
"""
15-
Condition for input/equation data. This condition must be used every
16-
time a Physics Informed Loss is needed in the Solver.
15+
Condition defined by input data and an equation. This condition can be
16+
used in a Physics Informed problems.
1717
"""
1818

1919
__slots__ = ["input", "equation"]
@@ -22,10 +22,10 @@ class InputEquationCondition(ConditionInterface):
2222

2323
def __new__(cls, input, equation):
2424
"""
25-
Instantiate the appropriate subclass of InputEquationCondition based on
26-
the type of input data.
25+
Instantiate the appropriate subclass of :class:`InputEquationCondition`
26+
based on the type of `input`.
2727
28-
:param input: Input data. It can be a LabelTensor or a Graph object.
28+
:param input: Input data for the condition.
2929
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
3030
:param EquationInterface equation: Equation object containing the
3131
equation function.

pina/condition/input_target_condition.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
class InputTargetCondition(ConditionInterface):
1313
"""
14-
Condition for domain/equation data. This condition must be used every
15-
time a Physics Informed or a Supervised Loss is needed in the Solver.
14+
Condition defined by input and target data. This condition can be used in
15+
both supervised learning and Physics-informed problems.
1616
"""
1717

1818
__slots__ = ["input", "target"]
@@ -25,15 +25,11 @@ def __new__(cls, input, target):
2525
the types of input and target data.
2626
2727
:param input: Input data for the condition.
28-
:type input: torch.Tensor | LabelTensor | Graph |
29-
torch_geometric.data.Data | list[Graph] |
30-
list[torch_geometric.data.Data] | tuple[Graph] |
31-
tuple[torch_geometric.data.Data]
28+
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
29+
list[Data] | tuple[Graph] | tuple[Data]
3230
:param target: Target data for the condition.
33-
:type target: torch.Tensor | LabelTensor | Graph |
34-
torch_geometric.data.Data | list[Graph] |
35-
list[torch_geometric.data.Data] | tuple[Graph] |
36-
tuple[torch_geometric.data.Data]
31+
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
32+
list[Data] | tuple[Graph] | tuple[Data]
3733
:return: Subclass of InputTargetCondition
3834
:rtype: pina.condition.input_target_condition.
3935
TensorInputTensorTargetCondition |
@@ -43,7 +39,7 @@ def __new__(cls, input, target):
4339
GraphInputTensorTargetCondition |
4440
pina.condition.input_target_condition.GraphInputGraphTargetCondition
4541
46-
:raises ValueError: If input and or target are not of type
42+
:raises ValueError: If `input` and/or `target` are not of type
4743
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
4844
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
4945
"""
@@ -85,12 +81,11 @@ def __new__(cls, input, target):
8581

8682
def __init__(self, input, target):
8783
"""
88-
Initialize the object storing the input and target data.
84+
Initialize the object by storing the `input` and `target` data.
8985
9086
:param input: Input data for the condition.
9187
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
92-
list[Data] | tuple[Graph] |
93-
tuple[Data]
88+
list[Data] | tuple[Graph] | tuple[Data]
9489
:param target: Target data for the condition.
9590
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
9691
list[Data] | tuple[Graph] | tuple[Data]
@@ -122,29 +117,29 @@ def _check_input_target_len(input, target):
122117
class TensorInputTensorTargetCondition(InputTargetCondition):
123118
"""
124119
InputTargetCondition subclass for :class:`torch.Tensor` or
125-
:class:`~pina.label_tensor.LabelTensor` input and target data.
120+
:class:`~pina.label_tensor.LabelTensor` `input` and `target` data.
126121
"""
127122

128123

129124
class TensorInputGraphTargetCondition(InputTargetCondition):
130125
"""
131126
InputTargetCondition subclass for :class:`torch.Tensor` or
132-
:class:`~pina.label_tensor.LabelTensor` input and
133-
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` target
127+
:class:`~pina.label_tensor.LabelTensor` `input` and
128+
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
134129
data.
135130
"""
136131

137132

138133
class GraphInputTensorTargetCondition(InputTargetCondition):
139134
"""
140135
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
141-
:class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` or
142-
:class:`~pina.label_tensor.LabelTensor` target data.
136+
:class:`~torch_geometric.data.Data` `input` and :class:`torch.Tensor` or
137+
:class:`~pina.label_tensor.LabelTensor` `target` data.
143138
"""
144139

145140

146141
class GraphInputGraphTargetCondition(InputTargetCondition):
147142
"""
148143
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
149-
:class:`~torch_geometric.data.Data` input and target data.
144+
:class:`~torch_geometric.data.Data` `input` and `target` data.
150145
"""

0 commit comments

Comments
 (0)