Skip to content

Commit 3ad2c8a

Browse files
committed
Implement graph conditions
1 parent 53b5fd7 commit 3ad2c8a

File tree

7 files changed

+273
-62
lines changed

7 files changed

+273
-62
lines changed

pina/condition/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
"DomainEquationCondition",
55
"InputPointsEquationCondition",
66
"InputOutputPointsCondition",
7+
"GraphInputOutputCondition",
8+
"GraphDataCondition",
9+
"GraphInputEquationCondition",
710
]
811

912
from .condition_interface import ConditionInterface
1013
from .domain_equation_condition import DomainEquationCondition
1114
from .input_equation_condition import InputPointsEquationCondition
1215
from .input_output_condition import InputOutputPointsCondition
16+
from .graph_condition import GraphInputOutputCondition
17+
from .graph_condition import GraphDataCondition
18+
from .graph_condition import GraphInputEquationCondition

pina/condition/condition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from .input_equation_condition import InputPointsEquationCondition
55
from .input_output_condition import InputOutputPointsCondition
66
from .data_condition import DataConditionInterface
7+
from .graph_condition import (
8+
GraphInputOutputCondition,
9+
GraphInputEquationCondition,
10+
)
711
import warnings
812
from ..utils import custom_warning_format
913

@@ -82,5 +86,9 @@ def __new__(cls, *args, **kwargs):
8286
return DataConditionInterface(**kwargs)
8387
elif sorted_keys == DataConditionInterface.__slots__[0]:
8488
return DataConditionInterface(**kwargs)
89+
elif sorted_keys == sorted(GraphInputOutputCondition.__slots__):
90+
return GraphInputOutputCondition(**kwargs)
91+
elif sorted_keys == sorted(GraphInputEquationCondition.__slots__):
92+
return GraphInputEquationCondition(**kwargs)
8593
else:
8694
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")

pina/condition/condition_interface.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33

44
class ConditionInterface(metaclass=ABCMeta):
55

6-
condition_types = ["physics", "supervised", "unsupervised"]
7-
86
def __init__(self, *args, **kwargs):
9-
self._condition_type = None
107
self._problem = None
118

129
@property
@@ -15,20 +12,4 @@ def problem(self):
1512

1613
@problem.setter
1714
def problem(self, value):
18-
self._problem = value
19-
20-
@property
21-
def condition_type(self):
22-
return self._condition_type
23-
24-
@condition_type.setter
25-
def condition_type(self, values):
26-
if not isinstance(values, (list, tuple)):
27-
values = [values]
28-
for value in values:
29-
if value not in ConditionInterface.condition_types:
30-
raise ValueError(
31-
"Unavailable type of condition, expected one of"
32-
f" {ConditionInterface.condition_types}."
33-
)
34-
self._condition_type = values
15+
self._problem = value

pina/condition/graph_condition.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from .condition_interface import ConditionInterface
2+
from ..graph import Graph
3+
from ..utils import check_consistency
4+
from torch_geometric.data import Data
5+
from ..equation.equation_interface import EquationInterface
6+
7+
8+
class GraphCondition(ConditionInterface):
9+
"""
10+
TODO
11+
"""
12+
13+
__slots__ = ["graph"]
14+
15+
def __new__(cls, graph):
16+
"""
17+
TODO : add docstring
18+
"""
19+
check_consistency(graph, (Graph, Data))
20+
graph = [graph] if isinstance(graph, Data) else graph
21+
22+
if all(g.y is not None for g in graph):
23+
return super().__new__(GraphInputOutputCondition)
24+
else:
25+
return super().__new__(GraphDataCondition)
26+
27+
def __init__(self, graph):
28+
29+
super().__init__()
30+
self.graph = graph
31+
32+
def __setattr__(self, key, value):
33+
if key == "graph":
34+
check_consistency(value, (Graph, Data))
35+
GraphCondition.__dict__[key].__set__(self, value)
36+
elif key in ("_problem", "_condition_type"):
37+
super().__setattr__(key, value)
38+
39+
40+
class GraphInputEquationCondition(ConditionInterface):
41+
42+
__slots__ = ["graph", "equation"]
43+
44+
def __init__(self, graph, equation):
45+
super().__init__()
46+
self.graph = graph
47+
self.equation = equation
48+
49+
def __setattr__(self, key, value):
50+
if key == "graph":
51+
check_consistency(value, (Graph, Data))
52+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
53+
elif key == "equation":
54+
check_consistency(value, (EquationInterface))
55+
GraphInputEquationCondition.__dict__[key].__set__(self, value)
56+
elif key in ("_problem", "_condition_type"):
57+
super().__setattr__(key, value)
58+
59+
60+
# The split between GraphInputOutputCondition and GraphDataCondition
61+
# distinguishes different types of graph conditions passed to problems.
62+
# This separation simplifies consistency checks during problem creation.
63+
class GraphDataCondition(GraphCondition):
64+
pass
65+
66+
67+
class GraphInputOutputCondition(GraphCondition):
68+
pass

pina/data/dataset.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -146,36 +146,31 @@ def __init__(
146146
self, conditions_dict, max_conditions_lengths, automatic_batching
147147
):
148148
super().__init__(conditions_dict, max_conditions_lengths)
149+
self.conditions_length = {
150+
k: len(v["graph"]) for k, v in self.conditions_dict.items()
151+
}
152+
self.length = max(self.conditions_length.values())
153+
149154
self.in_labels = {}
150155
self.out_labels = None
151156
if automatic_batching:
152157
self._getitem_func = self._getitem_int
153158
else:
154159
self._getitem_func = self._getitem_dummy
155160

156-
ex_data = conditions_dict[list(conditions_dict.keys())[0]][
157-
"input_points"
158-
][0]
161+
ex_data = conditions_dict[list(conditions_dict.keys())[0]]["graph"][0]
162+
159163
for name, attr in ex_data.items():
160164
if isinstance(attr, LabelTensor):
161165
self.in_labels[name] = attr.stored_labels
162-
ex_data = conditions_dict[list(conditions_dict.keys())[0]][
163-
"output_points"
164-
][0]
165-
if isinstance(ex_data, LabelTensor):
166-
self.out_labels = ex_data.labels
167166

168167
self._create_graph_batch_from_list = (
169168
self._labelise_batch(self._base_create_graph_batch_from_list)
170169
if self.in_labels
171170
else self._base_create_graph_batch_from_list
172171
)
173-
174-
self._create_output_batch = (
175-
self._labelise_tensor(self._base_create_output_batch)
176-
if self.out_labels is not None
177-
else self._base_create_output_batch
178-
)
172+
if hasattr(ex_data, "y"):
173+
self._divide_batch = self._extract_output(self._divide_batch)
179174

180175
def fetch_from_idx_list(self, idx):
181176
to_return_dict = {}
@@ -184,34 +179,32 @@ def fetch_from_idx_list(self, idx):
184179
condition_len = self.conditions_length[condition]
185180
if self.length > condition_len:
186181
cond_idx = [idx % condition_len for idx in cond_idx]
187-
to_return_dict[condition] = {
188-
k: (
189-
self._create_graph_batch_from_list([v[i] for i in idx])
190-
if isinstance(v, list)
191-
else self._create_output_batch(v[idx])
192-
)
193-
for k, v in data.items()
194-
}
182+
batch = self._create_graph_batch_from_list(
183+
[data["graph"][i] for i in idx]
184+
)
185+
to_return_dict[condition] = self._divide_batch(batch=batch)
186+
return to_return_dict
195187

188+
def _divide_batch(self, batch):
189+
"""
190+
Divide the batch into input and output points
191+
"""
192+
to_return_dict = {}
193+
to_return_dict["input_points"] = batch
194+
if hasattr(batch, "y"):
195+
to_return_dict["output_points"] = batch.y
196196
return to_return_dict
197197

198198
def _base_create_graph_batch_from_list(self, data):
199199
batch = PinaBatch.from_data_list(data)
200200
return batch
201201

202-
def _base_create_output_batch(self, data):
203-
out = data.reshape(-1, *data.shape[2:])
204-
return out
205-
206202
def _getitem_dummy(self, idx):
207203
return idx
208204

209205
def _getitem_int(self, idx):
210206
return {
211-
k: {
212-
k_data: v[k_data][idx % len(v["input_points"])]
213-
for k_data in v.keys()
214-
}
207+
k: {"graph": v["graph"][idx % len(v["graph"])]}
215208
for k, v in self.conditions_dict.items()
216209
}
217210

@@ -251,3 +244,23 @@ def create_graph_batch(self, data):
251244
if isinstance(data[0], Data):
252245
return self._create_graph_batch_from_list(data)
253246
return self._create_output_batch(data)
247+
248+
@staticmethod
249+
def _extract_output(func):
250+
@functools.wraps(func)
251+
def wrapper(*args, **kwargs):
252+
out = func(*args, **kwargs)
253+
out["output_points"] = kwargs["batch"].y
254+
return out
255+
256+
return wrapper
257+
258+
@staticmethod
259+
def _extract_cond_vars(func):
260+
@functools.wraps(func)
261+
def wrapper(*args, **kwargs):
262+
out = func(*args, **kwargs)
263+
out["conditional_variables"] = kwargs["batch"].conditional_vars
264+
return out
265+
266+
return wrapper

pina/problem/zoo/supervised_problem.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
import torch
12
from pina.problem import AbstractProblem
23
from pina import Condition
34
from pina import Graph
5+
from pina import LabelTensor
46

57

6-
class SupervisedProblem(AbstractProblem):
8+
class SupervisedProblem:
9+
10+
def __new__(cls, *args, **kwargs):
11+
12+
if sorted(list(kwargs.keys())) == sorted(["input_", "output_"]):
13+
return SupervisedTensorProblem(**kwargs)
14+
elif sorted(list(kwargs.keys())) == sorted(["graph_"]):
15+
return SupervisedGraphProblem(**kwargs)
16+
raise RuntimeError("Invalid arguments for SupervisedProblem")
17+
18+
19+
class SupervisedTensorProblem(AbstractProblem):
720
"""
821
A problem definition for supervised learning in PINA.
922
@@ -29,9 +42,54 @@ def __init__(self, input_, output_):
2942
:param output_: Output data of the problem
3043
:type output_: torch.Tensor
3144
"""
32-
if isinstance(input_, Graph):
33-
input_ = input_.data
45+
if not isinstance(input_, (torch.Tensor, LabelTensor)):
46+
raise ValueError(
47+
"The input data must be a torch.Tensor or a " "LabelTensor"
48+
)
49+
if not isinstance(output_, (torch.Tensor, LabelTensor)):
50+
raise ValueError(
51+
"The output data must be a torch.Tensor or a " "LabelTensor"
52+
)
53+
if isinstance(output_, LabelTensor):
54+
self.output_variables = output_.labels
55+
3456
self.conditions["data"] = Condition(
3557
input_points=input_, output_points=output_
3658
)
3759
super().__init__()
60+
61+
62+
class SupervisedGraphProblem(AbstractProblem):
63+
"""
64+
A problem definition for supervised learning in PINA.
65+
66+
This class allows an easy and straightforward definition of a Supervised problem,
67+
based on a single condition of type `InputOutputPointsCondition`
68+
69+
:Example:
70+
>>> import torch
71+
>>> from pina.graph import RadiusGraph
72+
>>> x = torch.rand((10, 100, 10))
73+
>>> pos = torch.rand((10, 100, 2))
74+
>>> y = torch.rand((10, 100, 2))
75+
>>> input_data = RadiusGraph(x=x, pos=pos, r=.2, y=y)
76+
>>> problem = SupervisedProblem(graph_=input_data)
77+
"""
78+
79+
conditions = dict()
80+
output_variables = None
81+
82+
def __init__(self, graph_):
83+
"""
84+
Initialize the SupervisedProblem class
85+
86+
:param graph_: Input data of the problem
87+
:type graph_: Graph
88+
"""
89+
if not isinstance(graph_, list) or not all(
90+
isinstance(g, Graph) for g in graph_
91+
):
92+
raise ValueError("The input data must be a Graph")
93+
94+
self.conditions["data"] = Condition(graph=graph_)
95+
super().__init__()

0 commit comments

Comments
 (0)