Skip to content

Commit c59ef8f

Browse files
committed
Adding logic
1 parent 7b4bdfe commit c59ef8f

File tree

6 files changed

+78
-71
lines changed

6 files changed

+78
-71
lines changed

pina/condition/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Module for conditions.
3+
"""
4+
15
__all__ = [
26
"Condition",
37
"ConditionInterface",
@@ -16,6 +20,7 @@
1620
]
1721

1822
from .condition_interface import ConditionInterface
23+
from .condition import Condition
1924
from .domain_equation_condition import DomainEquationCondition
2025
from .input_target_condition import (
2126
InputTargetCondition,

pina/condition/condition.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Condition module."""
22

3+
import warnings
4+
from .data_condition import DataCondition
35
from .domain_equation_condition import DomainEquationCondition
46
from .input_equation_condition import InputEquationCondition
57
from .input_target_condition import InputTargetCondition
6-
from .data_condition import DataCondition
7-
import warnings
88
from ..utils import custom_warning_format
99

1010
# Set the custom format for warnings
@@ -13,6 +13,13 @@
1313

1414

1515
def warning_function(new, old):
16+
"""Handle the deprecation warning.
17+
18+
:param new: Object to use instead of the old one.
19+
:type new: str
20+
:param old: Object to deprecate.
21+
:type old: str
22+
"""
1623
warnings.warn(
1724
f"'{old}' is deprecated and will be removed "
1825
f"in future versions. Please use '{new}' instead.",
@@ -48,7 +55,23 @@ class Condition:
4855
4956
Example::
5057
51-
>>> TODO
58+
>>> from pina import Condition
59+
>>> condition = Condition(
60+
... input=input,
61+
... target=target
62+
... )
63+
>>> condition = Condition(
64+
... domain=location,
65+
... equation=equation
66+
... )
67+
>>> condition = Condition(
68+
... input=input,
69+
... equation=equation
70+
... )
71+
>>> condition = Condition(
72+
... input=data,
73+
... conditional_variables=conditional_variables
74+
... )
5275
5376
"""
5477

pina/condition/data_condition.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
import torch
66
from torch_geometric.data import Data
7-
from . import ConditionInterface
7+
from .condition_interface import ConditionInterface
88
from ..label_tensor import LabelTensor
99
from ..graph import Graph
10-
from ..utils import check_consistency
1110

1211

1312
class DataCondition(ConditionInterface):
@@ -36,20 +35,21 @@ def __new__(cls, input, conditional_variables=None):
3635
:return: DataCondition subclass
3736
:rtype: TensorDataCondition or GraphDataCondition
3837
"""
39-
if cls == DataCondition and isinstance(
40-
input, (torch.Tensor, LabelTensor)
41-
):
38+
if cls != DataCondition:
39+
return super().__new__(cls)
40+
if isinstance(input, (torch.Tensor, LabelTensor)):
4241
subclass = TensorDataCondition
4342
return subclass.__new__(subclass, input, conditional_variables)
4443

45-
elif cls == DataCondition and isinstance(
46-
input, (Graph, Data, list, tuple)
47-
):
44+
if isinstance(input, (Graph, Data, list, tuple)):
4845
cls._check_graph_list_consistency(input)
4946
subclass = GraphDataCondition
5047
return subclass.__new__(subclass, input, conditional_variables)
5148

52-
return super().__new__(cls)
49+
raise ValueError(
50+
"Invalid input types. "
51+
"Please provide either Data or Graph objects."
52+
)
5353

5454
def __init__(self, input, conditional_variables=None):
5555
"""
@@ -67,17 +67,6 @@ def __init__(self, input, conditional_variables=None):
6767
self.input = input
6868
self.conditional_variables = conditional_variables
6969

70-
def __setattr__(self, key, value):
71-
if key == "input":
72-
check_consistency(value, self._avail_input_cls)
73-
DataCondition.__dict__[key].__set__(self, value)
74-
elif key == "conditional_variables":
75-
if value is not None:
76-
check_consistency(value, self._avail_conditional_variables_cls)
77-
DataCondition.__dict__[key].__set__(self, value)
78-
elif key in ("_problem"):
79-
super().__setattr__(key, value)
80-
8170

8271
class TensorDataCondition(DataCondition):
8372
"""

pina/condition/domain_equation_condition.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
DomainEquationCondition class definition.
33
"""
44

5-
import torch
6-
75
from .condition_interface import ConditionInterface
86
from ..utils import check_consistency
97
from ..domain import DomainInterface

pina/condition/input_equation_condition.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Module to define InputEquationCondition class and its subclasses.
33
"""
44

5-
import torch
65
from torch_geometric.data import Data
76
from .condition_interface import ConditionInterface
87
from ..label_tensor import LabelTensor
@@ -33,17 +32,25 @@ def __new__(cls, input, equation):
3332
:return: InputEquationCondition subclass
3433
:rtype: InputTensorEquationCondition or InputGraphEquationCondition
3534
"""
36-
if cls == InputEquationCondition and isinstance(input, LabelTensor):
37-
subclass = InputTensorEquationCondition
38-
return subclass.__new__(subclass, input, equation)
39-
elif cls == InputEquationCondition and isinstance(
40-
input, (Graph, Data, list, tuple)
41-
):
42-
cls._check_graph_list_consistency(input)
35+
36+
# If the class is already a subclass, return the instance
37+
if cls != InputEquationCondition:
38+
return super().__new__(cls)
39+
40+
# Instanciate the correct subclass
41+
if isinstance(input, (Graph, Data, list, tuple)):
4342
subclass = InputGraphEquationCondition
43+
cls._check_graph_list_consistency(input)
4444
subclass._check_label_tensor(input)
4545
return subclass.__new__(subclass, input, equation)
46-
return super().__new__(cls)
46+
if isinstance(input, LabelTensor):
47+
subclass = InputTensorEquationCondition
48+
return subclass.__new__(subclass, input, equation)
49+
50+
# If the input is not a LabelTensor or a Graph object raise an error
51+
raise ValueError(
52+
"The input data object must be a LabelTensor or a Graph object."
53+
)
4754

4855
def __init__(self, input, equation):
4956
"""

pina/condition/input_target_condition.py

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
import torch
66
from torch_geometric.data import Data
7-
from .condition_interface import ConditionInterface
87
from ..label_tensor import LabelTensor
98
from ..graph import Graph
10-
from ..utils import check_consistency
9+
from .condition_interface import ConditionInterface
1110

1211

1312
class InputTargetCondition(ConditionInterface):
@@ -34,43 +33,41 @@ def __new__(cls, input, target):
3433
TensorInputGraphTargetCondition or GraphInputTensorTargetCondition
3534
or GraphInputGraphTargetCondition
3635
"""
37-
if (
38-
cls == InputTargetCondition
39-
and isinstance(input, (torch.Tensor, LabelTensor))
40-
and isinstance(target, (torch.Tensor, LabelTensor))
36+
if cls != InputTargetCondition:
37+
return super().__new__(cls)
38+
39+
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
40+
target, (torch.Tensor, LabelTensor)
4141
):
4242
subclass = TensorInputTensorTargetCondition
4343
return subclass.__new__(subclass, input, target)
44-
45-
elif (
46-
cls == InputTargetCondition
47-
and isinstance(input, (torch.Tensor, LabelTensor))
48-
and isinstance(target, (Graph, Data, list, tuple))
44+
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
45+
target, (Graph, Data, list, tuple)
4946
):
5047
cls._check_graph_list_consistency(target)
5148
subclass = TensorInputGraphTargetCondition
5249
return subclass.__new__(subclass, input, target)
5350

54-
elif (
55-
cls == InputTargetCondition
56-
and isinstance(input, (Graph, Data, list, tuple))
57-
and isinstance(target, (torch.Tensor, LabelTensor))
51+
if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
52+
target, (torch.Tensor, LabelTensor)
5853
):
5954
cls._check_graph_list_consistency(input)
6055
subclass = GraphInputTensorTargetCondition
6156
return subclass.__new__(subclass, input, target)
6257

63-
elif (
64-
cls == InputTargetCondition
65-
and isinstance(input, (Graph, Data, list, tuple))
66-
and isinstance(target, (Graph, Data, list, tuple))
58+
if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
59+
target, (Graph, Data, list, tuple)
6760
):
6861
cls._check_graph_list_consistency(input)
6962
cls._check_graph_list_consistency(target)
7063
subclass = GraphInputGraphTargetCondition
7164
return subclass.__new__(subclass, input, target)
7265

73-
return super().__new__(cls)
66+
raise ValueError(
67+
"Invalid input/target types. "
68+
"Please provide either Data, Graph, LabelTensor or torch.Tensor "
69+
"objects."
70+
)
7471

7572
def __init__(self, input, target):
7673
"""
@@ -82,28 +79,16 @@ def __init__(self, input, target):
8279
:type target: torch.Tensor or Graph or Data
8380
"""
8481
super().__init__()
85-
if isinstance(input, (list, tuple)) or isinstance(
86-
target, (list, tuple)
87-
):
88-
self._check_input_target_len(input, target)
82+
self._check_input_target_len(input, target)
8983
self.input = input
9084
self.target = target
9185

92-
def __setattr__(self, key, value):
93-
if key == "input":
94-
check_consistency(value, (torch.Tensor, LabelTensor, Data, Graph))
95-
InputTargetCondition.__dict__[key].__set__(self, value)
96-
elif key == "target":
97-
if value is not None:
98-
check_consistency(
99-
value, (torch.Tensor, LabelTensor, Data, Graph)
100-
)
101-
InputTargetCondition.__dict__[key].__set__(self, value)
102-
elif key in ("_problem"):
103-
super().__setattr__(key, value)
104-
10586
@staticmethod
10687
def _check_input_target_len(input, target):
88+
if isinstance(input, (Graph, Data)) or isinstance(
89+
target, (Graph, Data)
90+
):
91+
return
10792
if len(input) != len(target):
10893
raise ValueError(
10994
"The input and target lists must have the same length."

0 commit comments

Comments
 (0)