Skip to content

Commit 9bfd926

Browse files
committed
Implement graph conditions and enhance Graph class
1 parent 1bbba45 commit 9bfd926

17 files changed

+816
-371
lines changed

pina/collector.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ def store_fixed_data(self):
5656
# get data
5757
keys = condition.__slots__
5858
values = [getattr(condition, name) for name in keys]
59-
values = [
60-
value.data if isinstance(value, Graph) else value
61-
for value in values
62-
]
6359
self.data_collections[condition_name] = dict(zip(keys, values))
6460
# condition now is ready
6561
self._is_conditions_ready[condition_name] = True

pina/condition/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
"DomainEquationCondition",
55
"InputPointsEquationCondition",
66
"InputOutputPointsCondition",
7+
"GraphInputOutputCondition",
8+
"GraphDataCondition",
9+
"GraphInputEquationCondition",
710
]
811

912
from .condition_interface import ConditionInterface
1013
from .domain_equation_condition import DomainEquationCondition
1114
from .input_equation_condition import InputPointsEquationCondition
1215
from .input_output_condition import InputOutputPointsCondition
16+
from .graph_condition import GraphInputOutputCondition
17+
from .graph_condition import GraphDataCondition
18+
from .graph_condition import GraphInputEquationCondition

pina/condition/condition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from .input_equation_condition import InputPointsEquationCondition
55
from .input_output_condition import InputOutputPointsCondition
66
from .data_condition import DataConditionInterface
7+
from .graph_condition import (
8+
GraphInputOutputCondition,
9+
GraphInputEquationCondition,
10+
)
711
import warnings
812
from ..utils import custom_warning_format
913

@@ -82,5 +86,9 @@ def __new__(cls, *args, **kwargs):
8286
return DataConditionInterface(**kwargs)
8387
elif sorted_keys == DataConditionInterface.__slots__[0]:
8488
return DataConditionInterface(**kwargs)
89+
elif sorted_keys == sorted(GraphInputOutputCondition.__slots__):
90+
return GraphInputOutputCondition(**kwargs)
91+
elif sorted_keys == sorted(GraphInputEquationCondition.__slots__):
92+
return GraphInputEquationCondition(**kwargs)
8593
else:
8694
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

pina/condition/condition_interface.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33

44
class ConditionInterface(metaclass=ABCMeta):
55

6-
condition_types = ["physics", "supervised", "unsupervised"]
7-
86
def __init__(self, *args, **kwargs):
9-
self._condition_type = None
107
self._problem = None
118

129
@property
@@ -16,19 +13,3 @@ def problem(self):
1613
@problem.setter
1714
def problem(self, value):
1815
self._problem = value
19-
20-
@property
21-
def condition_type(self):
22-
return self._condition_type
23-
24-
@condition_type.setter
25-
def condition_type(self, values):
26-
if not isinstance(values, (list, tuple)):
27-
values = [values]
28-
for value in values:
29-
if value not in ConditionInterface.condition_types:
30-
raise ValueError(
31-
"Unavailable type of condition, expected one of"
32-
f" {ConditionInterface.condition_types}."
33-
)
34-
self._condition_type = values

pina/condition/graph_condition.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from .condition_interface import ConditionInterface
2+
from ..graph import Graph
3+
from ..utils import check_consistency
4+
from torch_geometric.data import Data
5+
from ..equation.equation_interface import EquationInterface
6+
7+
8+
class GraphCondition(ConditionInterface):
9+
"""
10+
TODO
11+
"""
12+
13+
__slots__ = ["graph"]
14+
15+
def __new__(cls, graph):
16+
"""
17+
TODO : add docstring
18+
"""
19+
check_consistency(graph, (Graph, Data))
20+
graph = [graph] if isinstance(graph, Data) else graph
21+
22+
if all(g.y is not None for g in graph):
23+
return super().__new__(GraphInputOutputCondition)
24+
else:
25+
return super().__new__(GraphDataCondition)
26+
27+
def __init__(self, graph):
28+
29+
super().__init__()
30+
self.graph = graph
31+
32+
def __setattr__(self, key, value):
33+
if key == "graph":
34+
check_consistency(value, (Graph, Data))
35+
GraphCondition.__dict__[key].__set__(self, value)
36+
elif key in ("_problem", "_condition_type"):
37+
super().__setattr__(key, value)
38+
39+
40+
class GraphInputEquationCondition(ConditionInterface):
41+
42+
__slots__ = ["graph", "equation"]
43+
44+
def __init__(self, graph, equation):
45+
super().__init__()
46+
self.graph = graph
47+
self.equation = equation
48+
49+
def __setattr__(self, key, value):
50+
if key == "graph":
51+
check_consistency(value, (Graph, Data))
52+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
53+
elif key == "equation":
54+
check_consistency(value, (EquationInterface))
55+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
56+
elif key in ("_problem", "_condition_type"):
57+
super().__setattr__(key, value)
58+
59+
60+
# The split between GraphInputOutputCondition and GraphDataCondition
61+
# distinguishes different types of graph conditions passed to problems.
62+
# This separation simplifies consistency checks during problem creation.
63+
class GraphDataCondition(GraphCondition):
64+
pass
65+
66+
67+
class GraphInputOutputCondition(GraphCondition):
68+
pass

pina/data/data_module.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -84,39 +84,39 @@ def _collate_standard_dataloader(self, batch):
8484
conditions_names = batch[0].keys()
8585
# Condition names
8686
for condition_name in conditions_names:
87-
single_cond_dict = {}
8887
condition_args = batch[0][condition_name].keys()
89-
for arg in condition_args:
90-
data_list = [
91-
batch[idx][condition_name][arg]
92-
for idx in range(
93-
min(
94-
len(batch),
95-
self.max_conditions_lengths[condition_name],
96-
)
97-
)
98-
]
99-
single_cond_dict[arg] = self._collate(data_list)
100-
101-
batch_dict[condition_name] = single_cond_dict
88+
batch_dict[condition_name] = self._collate(
89+
condition_args, condition_name, batch
90+
)
10291
return batch_dict
10392

104-
@staticmethod
105-
def _collate_tensor_dataset(data_list):
106-
if isinstance(data_list[0], LabelTensor):
107-
return LabelTensor.stack(data_list)
108-
if isinstance(data_list[0], torch.Tensor):
109-
return torch.stack(data_list)
110-
raise RuntimeError("Data must be Tensors or LabelTensor ")
111-
112-
def _collate_graph_dataset(self, data_list):
113-
if isinstance(data_list[0], LabelTensor):
114-
return LabelTensor.cat(data_list)
115-
if isinstance(data_list[0], torch.Tensor):
116-
return torch.cat(data_list)
117-
if isinstance(data_list[0], Data):
118-
return self.dataset.create_graph_batch(data_list)
119-
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
93+
def _collate_tensor_dataset(self, condition_args, condition_name, batch):
94+
to_return_dict = {}
95+
for arg in condition_args:
96+
data_list = [
97+
batch[idx][condition_name][arg]
98+
for idx in range(
99+
min(len(batch), self.max_conditions_lengths[condition_name])
100+
)
101+
]
102+
if isinstance(data_list[0], LabelTensor):
103+
data = LabelTensor.stack(data_list)
104+
elif isinstance(data_list[0], torch.Tensor):
105+
data = torch.stack(data_list)
106+
107+
to_return_dict[arg] = data
108+
return to_return_dict
109+
110+
def _collate_graph_dataset(self, condition_args, condition_name, batch):
111+
data_list = [
112+
batch[idx][condition_name]
113+
for idx in range(
114+
min(len(batch), self.max_conditions_lengths[condition_name])
115+
)
116+
]
117+
return self.dataset._divide_batch(
118+
batch=self.dataset.create_graph_batch(data_list)
119+
)
120120

121121
def __call__(self, batch):
122122
return self.callable_function(batch)
@@ -285,7 +285,7 @@ def setup(self, stage=None):
285285

286286
@staticmethod
287287
def _split_condition(condition_dict, splits_dict):
288-
len_condition = len(condition_dict["input_points"])
288+
len_condition = len(list(condition_dict.values())[0])
289289

290290
lengths = [
291291
int(len_condition * length) for length in splits_dict.values()
@@ -308,7 +308,7 @@ def _split_condition(condition_dict, splits_dict):
308308
if k != "equation"
309309
# Equations are NEVER dataloaded
310310
}
311-
if offset + stage_len >= len_condition:
311+
if offset + stage_len > len_condition:
312312
offset = len_condition - 1
313313
continue
314314
offset += stage_len
@@ -343,7 +343,7 @@ def _apply_shuffle(condition_dict, len_data):
343343
condition_name,
344344
condition_dict,
345345
) in collector.data_collections.items():
346-
len_data = len(condition_dict["input_points"])
346+
len_data = len(list(condition_dict.values())[0])
347347
if self.shuffle:
348348
_apply_shuffle(condition_dict, len_data)
349349
for key, data in self._split_condition(
@@ -390,12 +390,12 @@ def find_max_conditions_lengths(self, split):
390390
max_conditions_lengths = {}
391391
for k, v in self.collector_splits[split].items():
392392
if self.batch_size is None:
393-
max_conditions_lengths[k] = len(v["input_points"])
393+
max_conditions_lengths[k] = len(list(v.values())[0])
394394
elif self.repeat:
395395
max_conditions_lengths[k] = self.batch_size
396396
else:
397397
max_conditions_lengths[k] = min(
398-
len(v["input_points"]), self.batch_size
398+
len(list(v.values())[0]), self.batch_size
399399
)
400400
return max_conditions_lengths
401401

0 commit comments

Comments
 (0)