Skip to content

Commit 8145688

Browse files
committed
Add consistency check for graph data
1 parent 20f67a0 commit 8145688

File tree

5 files changed

+116
-1
lines changed

5 files changed

+116
-1
lines changed
Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,79 @@
1+
"""
2+
Module that defines the ConditionInterface class.
3+
"""
4+
15
from abc import ABCMeta
6+
from torch_geometric.data import Data
7+
from ..label_tensor import LabelTensor
8+
from ..graph import Graph
29

310

411
class ConditionInterface(metaclass=ABCMeta):
12+
"""
13+
Abstract class which defines a common interface for all the conditions.
14+
"""
515

6-
def __init__(self, *args, **kwargs):
16+
def __init__(self):
717
self._problem = None
818

919
@property
1020
def problem(self):
21+
"""
22+
Return the problem to which the condition is associated.
23+
24+
:return: Problem to which the condition is associated
25+
:rtype: pina.problem.AbstractProblem
26+
"""
1127
return self._problem
1228

1329
@problem.setter
1430
def problem(self, value):
1531
self._problem = value
32+
33+
@staticmethod
34+
def _check_graph_list_consistency(data_list):
35+
36+
# If the data is a Graph or Data object, return (do not need to check
37+
# anything)
38+
if isinstance(data_list, (Graph, Data)):
39+
return
40+
data = data_list[0]
41+
# Store the keys of the first element in the list
42+
keys = sorted(list(data.keys()))
43+
44+
# Store the type of each tensor inside first element Data/Graph object
45+
data_types = {name: tensor.__class__ for name, tensor in data.items()}
46+
47+
# Store the labels of each LabelTensor inside first element Data/Graph
48+
# object
49+
labels = {
50+
name: tensor.labels
51+
for name, tensor in data.items()
52+
if isinstance(tensor, LabelTensor)
53+
}
54+
# Iterate over the list of Data/Graph objects
55+
print(data_types)
56+
for data in data_list[1:]:
57+
# Check if the keys of the current element are the same as the first
58+
# element
59+
if sorted(list(data.keys())) != keys:
60+
raise ValueError(
61+
"All elements in the list must have the same keys."
62+
)
63+
64+
for name, tensor in data.items():
65+
# Check if the type of each tensor inside the current element
66+
# is the same as the first element
67+
if tensor.__class__ is not data_types[name]:
68+
raise ValueError(
69+
f"Data {name} must be a {data_types[name]}, got "
70+
f"{tensor.__class__}"
71+
)
72+
73+
# If the tensor is a LabelTensor, check if the labels are the
74+
# same as the first element
75+
if isinstance(tensor, LabelTensor):
76+
if tensor.labels != labels[name]:
77+
raise ValueError(
78+
f"Data {name} must have the same labels"
79+
)

pina/condition/data_condition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,7 @@ class GraphDataCondition(DataCondition):
8888
"""
8989
DataCondition for Graph/Data input data
9090
"""
91+
92+
def __init__(self, input, conditional_variables=None):
93+
self._check_graph_list_consistency(input)
94+
super().__init__(input, conditional_variables)

pina/condition/input_equation_condition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,7 @@ class InputGraphEquationCondition(InputEquationCondition):
7878
"""
7979
InputEquationCondition subclass for Graph input data.
8080
"""
81+
82+
def __init__(self, input, equation):
83+
super().__init__(input, equation)
84+
self._check_graph_list_consistency(input)

pina/condition/input_target_condition.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ class TensorInputGraphTargetCondition(InputTargetCondition):
9999
data.
100100
"""
101101

102+
def __init__(self, input, target):
103+
self._check_graph_list_consistency(target)
104+
super().__init__(input, target)
105+
102106
@staticmethod
103107
def _check_input_target_consistency(input, target):
104108
if isinstance(target, (Graph, Data)):
@@ -115,6 +119,10 @@ class GraphInputTensorTargetCondition(InputTargetCondition):
115119
data.
116120
"""
117121

122+
def __init__(self, input, target):
123+
self._check_graph_list_consistency(input)
124+
super().__init__(input, target)
125+
118126
@staticmethod
119127
def _check_input_target_consistency(input, target):
120128
if isinstance(input, (Graph, Data)):
@@ -130,6 +138,11 @@ class GraphInputGraphTargetCondition(InputTargetCondition):
130138
InputTargetCondition subclass for Graph/Data input and target data.
131139
"""
132140

141+
def __init__(self, input, target):
142+
self._check_graph_list_consistency(input)
143+
self._check_graph_list_consistency(target)
144+
super().__init__(input, target)
145+
133146
@staticmethod
134147
def _check_input_target_consistency(input, target):
135148
if isinstance(input, list) and isinstance(target, list):

tests/test_condition.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,26 @@
4848
for x_, pos_ in zip(x, pos)
4949
]
5050

51+
x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"])
52+
pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"])
53+
radius = 0.1
54+
input_graph_lt = [
55+
RadiusGraph(
56+
x=x[i],
57+
pos=pos[i],
58+
radius=radius,
59+
)
60+
for i in range(len(x))
61+
]
62+
target_graph_lt = [
63+
RadiusGraph(
64+
x=x[i],
65+
pos=pos[i],
66+
radius=radius,
67+
)
68+
for i in range(len(x))
69+
]
70+
5171
input_single_graph = input_graph[0]
5272
target_single_graph = target_graph[0]
5373

@@ -80,6 +100,16 @@ def test_init_input_target():
80100
with pytest.raises(ValueError):
81101
Condition(input=example_domain, target=example_domain)
82102

103+
# Test wrong graph condition initialisation
104+
input = [input_graph[0], input_graph_lt[0]]
105+
target = [target_graph[0], target_graph_lt[0]]
106+
with pytest.raises(ValueError):
107+
Condition(input=input, target=target)
108+
109+
input_graph_lt[0].x.labels = ["a", "b"]
110+
with pytest.raises(ValueError):
111+
Condition(input=input_graph_lt, target=target_graph_lt)
112+
83113

84114
def test_init_domain_equation():
85115
cond = Condition(domain=example_domain, equation=FixedValue(0.0))

0 commit comments

Comments
 (0)