Skip to content

Commit f26b667

Browse files
Improve conditions and refactor dataset classes (#475)
* Reimplement conditions * Refactor datasets and implement LabelBatch --------- Co-authored-by: Dario Coscia <[email protected]>
1 parent 9b08a1c commit f26b667

40 files changed

+932
-539
lines changed

.pylintrc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ logging-modules=logging
214214

215215

216216
[FORMAT]
217-
218217
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
219218
expected-line-ending-format=
220219

@@ -250,6 +249,8 @@ single-line-if-stmt=no
250249

251250

252251
[BASIC]
252+
# Allow redefinition of input builtins
253+
allowed-redefined-builtins=input
253254

254255
# Naming hint for argument names
255256
argument-name-hint=(([a-z][a-z0-9_]{2,30})|(_[a-z0-9_]*))$
@@ -401,7 +402,7 @@ max-returns=6
401402
max-statements=50
402403

403404
# Minimum number of public methods for a class (see R0903).
404-
min-public-methods=2
405+
min-public-methods=0
405406

406407

407408
[CLASSES]

code_formatter.sh

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,19 @@
22

33
#######################################
44

5-
required_command="yapf unexpand"
6-
code_directories="pina tests"
5+
required_command="black"
6+
code_directories=("pina" "tests")
77

88
#######################################
99

10-
usage() {
11-
echo
12-
echo -e "\tUsage: $0 [files]"
13-
echo
14-
echo -e "\tIf not files are specified, script formats all ".py" files"
15-
echo -e "\tin code directories ($code_directories); otherwise, formats"
16-
echo -e "\tall given files"
17-
echo
18-
echo -e "\tRequired command: $required_command"
19-
echo
20-
exit 0
21-
}
22-
23-
24-
[[ $1 == "-h" ]] && usage
25-
2610
# Test for required program
27-
for comm in $required_command; do
28-
command -v $comm >/dev/null 2>&1 || {
29-
echo "I require $comm but it's not installed. Aborting." >&2;
30-
exit 1
31-
}
32-
done
33-
34-
# Find all python files in code directories
35-
python_files=""
36-
for dir in $code_directories; do
37-
python_files="$python_files $(find $dir -name '*.py')"
38-
done
39-
[[ $# != 0 ]] && python_files=$@
40-
41-
42-
# Here the important part: yapf format the files.
43-
for file in $python_files; do
44-
echo "Making beatiful $file..."
45-
[[ ! -f $file ]] && echo "$file does not exist; $0 -h for more info" && exit
46-
47-
yapf --style='{
48-
based_on_style: pep8,
49-
indent_width: 4,
50-
column_limit: 80
51-
}' -i $file
52-
done
11+
if ! command -v $required_command >/dev/null 2>&1; then
12+
echo "I require $required_command but it's not installed. Install dev dependencies."
13+
echo "Aborting." >&2
14+
exit 1
15+
fi
16+
17+
# Run black formatter
18+
for dir in "${code_directories[@]}"; do
19+
python -m black --line-length 80 "$dir"
20+
done

pina/collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,6 @@ def store_sample_domains(self):
7676
samples = self.problem.discretised_domains[condition.domain]
7777

7878
self.data_collections[condition_name] = {
79-
"input_points": samples,
79+
"input": samples,
8080
"equation": condition.equation,
8181
}

pina/condition/__init__.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,41 @@
1+
"""
2+
Module for conditions.
3+
"""
4+
15
__all__ = [
26
"Condition",
37
"ConditionInterface",
48
"DomainEquationCondition",
5-
"InputPointsEquationCondition",
6-
"InputOutputPointsCondition",
9+
"InputTargetCondition",
10+
"TensorInputTensorTargetCondition",
11+
"TensorInputGraphTargetCondition",
12+
"GraphInputTensorTargetCondition",
13+
"GraphInputGraphTargetCondition",
14+
"InputEquationCondition",
15+
"InputTensorEquationCondition",
16+
"InputGraphEquationCondition",
17+
"DataCondition",
18+
"GraphDataCondition",
19+
"TensorDataCondition",
720
]
821

922
from .condition_interface import ConditionInterface
23+
from .condition import Condition
1024
from .domain_equation_condition import DomainEquationCondition
11-
from .input_equation_condition import InputPointsEquationCondition
12-
from .input_output_condition import InputOutputPointsCondition
25+
from .input_target_condition import (
26+
InputTargetCondition,
27+
TensorInputTensorTargetCondition,
28+
TensorInputGraphTargetCondition,
29+
GraphInputTensorTargetCondition,
30+
GraphInputGraphTargetCondition,
31+
)
32+
from .input_equation_condition import (
33+
InputEquationCondition,
34+
InputTensorEquationCondition,
35+
InputGraphEquationCondition,
36+
)
37+
from .data_condition import (
38+
DataCondition,
39+
GraphDataCondition,
40+
TensorDataCondition,
41+
)

pina/condition/condition.py

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
11
"""Condition module."""
22

3-
from .domain_equation_condition import DomainEquationCondition
4-
from .input_equation_condition import InputPointsEquationCondition
5-
from .input_output_condition import InputOutputPointsCondition
6-
from .data_condition import DataConditionInterface
73
import warnings
4+
from .data_condition import DataCondition
5+
from .domain_equation_condition import DomainEquationCondition
6+
from .input_equation_condition import InputEquationCondition
7+
from .input_target_condition import InputTargetCondition
88
from ..utils import custom_warning_format
99

1010
# Set the custom format for warnings
1111
warnings.formatwarning = custom_warning_format
1212
warnings.filterwarnings("always", category=DeprecationWarning)
1313

1414

15+
def warning_function(new, old):
16+
"""Handle the deprecation warning.
17+
18+
:param new: Object to use instead of the old one.
19+
:type new: str
20+
:param old: Object to deprecate.
21+
:type old: str
22+
"""
23+
warnings.warn(
24+
f"'{old}' is deprecated and will be removed "
25+
f"in future versions. Please use '{new}' instead.",
26+
DeprecationWarning,
27+
)
28+
29+
1530
class Condition:
1631
"""
1732
The class ``Condition`` is used to represent the constraints (physical
@@ -40,16 +55,32 @@ class Condition:
4055
4156
Example::
4257
43-
>>> TODO
58+
>>> from pina import Condition
59+
>>> condition = Condition(
60+
... input=input,
61+
... target=target
62+
... )
63+
>>> condition = Condition(
64+
... domain=location,
65+
... equation=equation
66+
... )
67+
>>> condition = Condition(
68+
... input=input,
69+
... equation=equation
70+
... )
71+
>>> condition = Condition(
72+
... input=data,
73+
... conditional_variables=conditional_variables
74+
... )
4475
4576
"""
4677

4778
__slots__ = list(
4879
set(
49-
InputOutputPointsCondition.__slots__
50-
+ InputPointsEquationCondition.__slots__
80+
InputTargetCondition.__slots__
81+
+ InputEquationCondition.__slots__
5182
+ DomainEquationCondition.__slots__
52-
+ DataConditionInterface.__slots__
83+
+ DataCondition.__slots__
5384
)
5485
)
5586

@@ -62,25 +93,30 @@ def __new__(cls, *args, **kwargs):
6293
)
6394

6495
# back-compatibility 0.1
65-
if "location" in kwargs.keys():
96+
keys = list(kwargs.keys())
97+
if "location" in keys:
6698
kwargs["domain"] = kwargs.pop("location")
67-
warnings.warn(
68-
f"'location' is deprecated and will be removed "
69-
f"in future versions. Please use 'domain' instead.",
70-
DeprecationWarning,
71-
)
99+
warning_function(new="domain", old="location")
72100

73-
sorted_keys = sorted(kwargs.keys())
101+
if "input_points" in keys:
102+
kwargs["input"] = kwargs.pop("input_points")
103+
warning_function(new="input", old="input_points")
74104

75-
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
76-
return InputOutputPointsCondition(**kwargs)
77-
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
78-
return InputPointsEquationCondition(**kwargs)
79-
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
105+
if "output_points" in keys:
106+
kwargs["target"] = kwargs.pop("output_points")
107+
warning_function(new="target", old="output_points")
108+
109+
sorted_keys = sorted(kwargs.keys())
110+
if sorted_keys == sorted(InputTargetCondition.__slots__):
111+
return InputTargetCondition(**kwargs)
112+
if sorted_keys == sorted(InputEquationCondition.__slots__):
113+
return InputEquationCondition(**kwargs)
114+
if sorted_keys == sorted(DomainEquationCondition.__slots__):
80115
return DomainEquationCondition(**kwargs)
81-
elif sorted_keys == sorted(DataConditionInterface.__slots__):
82-
return DataConditionInterface(**kwargs)
83-
elif sorted_keys == DataConditionInterface.__slots__[0]:
84-
return DataConditionInterface(**kwargs)
85-
else:
86-
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
116+
if (
117+
sorted_keys == sorted(DataCondition.__slots__)
118+
or sorted_keys[0] == DataCondition.__slots__[0]
119+
):
120+
return DataCondition(**kwargs)
121+
122+
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,84 @@
1+
"""
2+
Module that defines the ConditionInterface class.
3+
"""
4+
15
from abc import ABCMeta
6+
from torch_geometric.data import Data
7+
from ..label_tensor import LabelTensor
8+
from ..graph import Graph
29

310

411
class ConditionInterface(metaclass=ABCMeta):
12+
"""
13+
Abstract class which defines a common interface for all the conditions.
14+
"""
515

6-
condition_types = ["physics", "supervised", "unsupervised"]
7-
8-
def __init__(self, *args, **kwargs):
9-
self._condition_type = None
16+
def __init__(self):
1017
self._problem = None
1118

1219
@property
1320
def problem(self):
21+
"""
22+
Return the problem to which the condition is associated.
23+
24+
:return: Problem to which the condition is associated
25+
:rtype: pina.problem.AbstractProblem
26+
"""
1427
return self._problem
1528

1629
@problem.setter
1730
def problem(self, value):
1831
self._problem = value
1932

20-
@property
21-
def condition_type(self):
22-
return self._condition_type
23-
24-
@condition_type.setter
25-
def condition_type(self, values):
26-
if not isinstance(values, (list, tuple)):
27-
values = [values]
28-
for value in values:
29-
if value not in ConditionInterface.condition_types:
33+
@staticmethod
34+
def _check_graph_list_consistency(data_list):
35+
36+
# If the data is a Graph or Data object, return (do not need to check
37+
# anything)
38+
if isinstance(data_list, (Graph, Data)):
39+
return
40+
41+
# check all elements in the list are of the same type
42+
if not all(isinstance(i, (Graph, Data)) for i in data_list):
43+
raise ValueError(
44+
"Invalid input types. "
45+
"Please provide either Data or Graph objects."
46+
)
47+
data = data_list[0]
48+
# Store the keys of the first element in the list
49+
keys = sorted(list(data.keys()))
50+
51+
# Store the type of each tensor inside first element Data/Graph object
52+
data_types = {name: tensor.__class__ for name, tensor in data.items()}
53+
54+
# Store the labels of each LabelTensor inside first element Data/Graph
55+
# object
56+
labels = {
57+
name: tensor.labels
58+
for name, tensor in data.items()
59+
if isinstance(tensor, LabelTensor)
60+
}
61+
62+
# Iterate over the list of Data/Graph objects
63+
for data in data_list[1:]:
64+
# Check if the keys of the current element are the same as the first
65+
# element
66+
if sorted(list(data.keys())) != keys:
3067
raise ValueError(
31-
"Unavailable type of condition, expected one of"
32-
f" {ConditionInterface.condition_types}."
68+
"All elements in the list must have the same keys."
3369
)
34-
self._condition_type = values
70+
for name, tensor in data.items():
71+
# Check if the type of each tensor inside the current element
72+
# is the same as the first element
73+
if tensor.__class__ is not data_types[name]:
74+
raise ValueError(
75+
f"Data {name} must be a {data_types[name]}, got "
76+
f"{tensor.__class__}"
77+
)
78+
# If the tensor is a LabelTensor, check if the labels are the
79+
# same as the first element
80+
if isinstance(tensor, LabelTensor):
81+
if tensor.labels != labels[name]:
82+
raise ValueError(
83+
"LabelTensor must have the same labels"
84+
)

0 commit comments

Comments
 (0)