
Commit 6f52c79

Fix bug in Collector with Graph data
1 parent 78e4562 commit 6f52c79

6 files changed, +101 -31 lines changed


pina/collector.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,4 @@
+from .graph import Graph
 from .utils import check_consistency


@@ -52,6 +53,8 @@ def store_fixed_data(self):
             # get data
             keys = condition.__slots__
             values = [getattr(condition, name) for name in keys]
+            values = [value.data if isinstance(
+                value, Graph) else value for value in values]
             self.data_collections[condition_name] = dict(zip(keys, values))
             # condition now is ready
             self._is_conditions_ready[condition_name] = True
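The two added lines unwrap `Graph` values before storage: a `Graph` keeps its `torch_geometric.data.Data` objects in a `.data` list, and the collector now stores that list instead of the wrapper. A minimal sketch of the idea, with `FakeGraph` as a hypothetical stand-in for `pina.graph.Graph`:

```python
import torch
from torch_geometric.data import Data

class FakeGraph:
    """Hypothetical stand-in for pina.graph.Graph: Data objects live in .data."""
    def __init__(self, data):
        self.data = data

values = [FakeGraph([Data(x=torch.rand(4, 2))]), torch.rand(10, 2)]
# Same unwrapping as the new lines in store_fixed_data:
# Graph -> its list of Data objects, anything else passes through.
values = [v.data if isinstance(v, FakeGraph) else v for v in values]
assert isinstance(values[0], list) and isinstance(values[0][0], Data)
```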

pina/data/data_module.py

Lines changed: 30 additions & 7 deletions
@@ -3,11 +3,11 @@
 from lightning.pytorch import LightningDataModule
 import torch
 from ..label_tensor import LabelTensor
-from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
-    RandomSampler
+from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 from ..collector import Collector
+from torch_geometric.data import Data, Batch


 class DummyDataloader:
@@ -86,9 +86,8 @@ def _collate_standard_dataloader(self, batch):
                 single_cond_dict[arg] = LabelTensor.stack(data_list)
             elif isinstance(data_list[0], torch.Tensor):
                 single_cond_dict[arg] = torch.stack(data_list)
-            else:
-                raise NotImplementedError(
-                    f"Data type {type(data_list[0])} not supported")
+            elif isinstance(data_list[0], Data):
+                single_cond_dict[arg] = Batch.from_data_list(data_list)
             batch_dict[condition_name] = single_cond_dict
         return batch_dict

@@ -125,7 +124,7 @@ def __init__(self,
                  batch_size=None,
                  shuffle=True,
                  repeat=False,
-                 automatic_batching=False,
+                 automatic_batching=None,
                  num_workers=0,
                  pin_memory=False,
                  ):
@@ -158,7 +157,6 @@ def __init__(self,
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.automatic_batching = automatic_batching
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.repeat = repeat
@@ -192,6 +190,10 @@ def __init__(self,
         collector = Collector(problem)
         collector.store_fixed_data()
         collector.store_sample_domains()
+
+        self.automatic_batching = self._set_automatic_batching_option(
+            collector, automatic_batching)
+
         if batch_size is None and num_workers != 0:
             warnings.warn(
                 "Setting num_workers when batch_size is None has no effect on "
@@ -393,6 +395,27 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size):
         if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
             raise ValueError("The sum of the splits must be 1")

+    @staticmethod
+    def _set_automatic_batching_option(collector, automatic_batching):
+        """
+        Determine whether automatic batching should be enabled.
+
+        If all 'input_points' in the collector's data collections are
+        tensors (torch.Tensor or LabelTensor), the provided
+        `automatic_batching` value is respected; otherwise, mainly in the
+        Graph scenario, automatic batching is forced on.
+
+        :param Collector collector: Collector object which contains all the
+            data retrieved from the input conditions
+        :param bool automatic_batching: Whether the user wants automatic
+            batching enabled or not
+        """
+        if all(isinstance(v['input_points'], (torch.Tensor, LabelTensor))
+               for v in collector.data_collections.values()):
+            return automatic_batching if automatic_batching is not None \
+                else False
+        return True
+
     @property
     def input_points(self):
         """

pina/data/dataset.py

Lines changed: 57 additions & 15 deletions
@@ -2,9 +2,11 @@
 This module provide basic data management functionalities
 """
 import torch
+import functools
 from torch.utils.data import Dataset
 from abc import abstractmethod
 from torch_geometric.data import Batch
+from pina import LabelTensor


 class PinaDatasetFactory:
@@ -107,10 +109,25 @@ class PinaGraphDataset(PinaDataset):
     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
-        if automatic_batching:
-            self._getitem_func = self._getitem_int
-        else:
-            self._getitem_func = self._getitem_list
+        self.in_labels = {}
+        self.out_labels = None
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['input_points'][0]
+        for name, attr in ex_data.items():
+            if isinstance(attr, LabelTensor):
+                self.in_labels[name] = attr.stored_labels
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['output_points'][0]
+        if isinstance(ex_data, LabelTensor):
+            self.out_labels = ex_data.labels
+
+        self._create_graph_batch_from_list = self._labelise_batch(
+            self._base_create_graph_batch_from_list) if self.in_labels \
+            else self._base_create_graph_batch_from_list
+
+        self._create_output_batch = self._labelise_tensor(
+            self._base_create_output_batch) if self.out_labels is not None \
+            else self._base_create_output_batch

     def fetch_from_idx_list(self, idx):
         to_return_dict = {}
@@ -119,18 +136,22 @@ def fetch_from_idx_list(self, idx):
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
                 cond_idx = [idx % condition_len for idx in cond_idx]
-            to_return_dict[condition] = {k: Batch.from_data_list([
-                v[i] for i in cond_idx])
-                if isinstance(v, list)
-                else v[
-                    cond_idx].reshape(
-                        -1, *v[cond_idx].shape[2:])
-                for k, v in data.items()
-                }
+            to_return_dict[condition] = {
+                k: self._create_graph_batch_from_list(v, cond_idx)
+                if isinstance(v, list)
+                else self._create_output_batch(v, cond_idx)
+                for k, v in data.items()
+            }
+
         return to_return_dict

-    def _getitem_list(self, idx):
-        return idx
+    def _base_create_graph_batch_from_list(self, data, idx):
+        batch = Batch.from_data_list([data[i] for i in idx])
+        return batch
+
+    def _base_create_output_batch(self, data, idx):
+        out = data[idx].reshape(-1, *data[idx].shape[2:])
+        return out

     def _getitem_int(self, idx):
         return {
@@ -143,4 +164,25 @@ def get_all_data(self):
         return self.fetch_from_idx_list(index)

     def __getitem__(self, idx):
-        return self._getitem_func(idx)
+        return self._getitem_int(idx) if isinstance(idx, int) else \
+            self.fetch_from_idx_list(idx=idx)
+
+    def _labelise_batch(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            batch = func(*args, **kwargs)
+            for k, v in self.in_labels.items():
+                tmp = batch[k]
+                tmp.labels = v
+                batch[k] = tmp
+            return batch
+        return wrapper
+
+    def _labelise_tensor(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            out = func(*args, **kwargs)
+            if isinstance(out, LabelTensor):
+                out.labels = self.out_labels
+            return out
+        return wrapper
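The two `_labelise_*` helpers are plain decorators: the base builder produces an unlabelled batch, and the wrapper re-attaches the labels recorded in `__init__` afterwards. A self-contained sketch of the same pattern, with a plain dict standing in for the `torch_geometric` `Batch` object (names are illustrative):

```python
import functools

def labelise(func, labels):
    """Wrap a batch builder so stored labels are re-attached to its output,
    mirroring PinaGraphDataset._labelise_batch (dict stands in for Batch)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        batch = func(*args, **kwargs)
        batch["labels"] = labels  # re-attach labels lost during batching
        return batch
    return wrapper

def build_batch(values):
    return {"x": values}

build_labelled = labelise(build_batch, labels=["u", "v"])
print(build_labelled([1, 2, 3]))  # {'x': [1, 2, 3], 'labels': ['u', 'v']}
```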

pina/graph.py

Lines changed: 9 additions & 6 deletions
@@ -108,16 +108,14 @@ def __init__(
             x)

         # Perform the graph construction
-        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
+        self._build_graph_list(
+            x, pos, edge_index, edge_attr, additional_params)

     def _build_graph_list(self, x, pos, edge_index, edge_attr,
                           additional_params):
         for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
-            if isinstance(x_, LabelTensor):
-                x_ = x_.tensor
             add_params_local = {k: v[i] for k, v in additional_params.items()}
             if edge_attr is not None:
-
                 self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                       edge_attr=edge_attr[i],
                                       **add_params_local))
@@ -165,7 +163,8 @@ def _check_input_consistency(x, pos, edge_index=None):
         # If edge_index is a 3D tensor, we split it into a list of 2D tensors
         if edge_index is not None:
             if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
-                edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
+                edge_index = [edge_index[i]
+                              for i in range(edge_index.shape[0])]
             elif not (isinstance(edge_index, list) and all(
                 t.ndim == 2 for t in edge_index)) and not (
                     isinstance(edge_index,
@@ -219,7 +218,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
         if isinstance(edge_attr, list):
             if len(edge_attr) != data_len:
                 raise TypeError("edge_attr must have the same length as x "
-                                    "and pos.")
+                                "and pos.")
             return [edge_attr] * data_len

         if build_edge_attr:
@@ -258,6 +257,8 @@ def _radius_graph(points, r):
         """
         dist = torch.cdist(points, points, p=2)
         edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index


@@ -293,4 +294,6 @@ def _knn_graph(points, k):
         row = torch.arange(points.size(0)).repeat_interleave(k)
         col = knn_indices.flatten()
         edge_index = torch.stack([row, col], dim=0)
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index
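For context, the radius-graph construction the new guard sits in is short enough to reproduce. The standalone sketch below mirrors the two lines shown in the hunk; note that `dist <= r` also keeps self-loops, since every point is at zero distance from itself:

```python
import torch

def radius_graph(points, r):
    """Edge index of all point pairs within distance r (self-loops included),
    mirroring the two-line construction in pina.graph."""
    dist = torch.cdist(points, points, p=2)  # pairwise Euclidean distances
    return torch.nonzero(dist <= r, as_tuple=False).t()

points = torch.tensor([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
print(radius_graph(points, r=1.5))
# tensor([[0, 0, 1, 1, 2],
#         [0, 1, 0, 1, 2]])
```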

pina/trainer.py

Lines changed: 0 additions & 2 deletions
@@ -105,8 +105,6 @@ def __init__(self,
         # checking compilation and automatic batching
         if compile is None or sys.platform == "win32":
             compile = False
-        if automatic_batching is None:
-            automatic_batching = False

         # set attributes
         self.compile = compile

pina/utils.py

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ def labelize_forward(forward, input_variables, output_variables):
     :type output_variables: list[str] | tuple[str]
     """
     def wrapper(x):
-        x = x.extract(input_variables)
+        if isinstance(x, LabelTensor):
+            x = x.extract(input_variables)
         output = forward(x)
         # keep it like this, directly using LabelTensor(...) raises errors
         # when compiling the code
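The guard matters because, with graph data, the wrapped forward can now receive plain tensors rather than `LabelTensor`s, and only the latter support `extract`. A minimal sketch of the pattern, with a hypothetical `MiniLabelTensor` standing in for `pina.label_tensor.LabelTensor`:

```python
import torch

class MiniLabelTensor:
    """Hypothetical stand-in for LabelTensor: a tensor with named columns."""
    def __init__(self, tensor, labels):
        self.tensor = tensor
        self.labels = labels

    def extract(self, variables):
        # Select the columns whose labels match the requested variables.
        cols = [self.labels.index(v) for v in variables]
        return self.tensor[:, cols]

def wrapper(x, input_variables=("x",)):
    # Same guard as in labelize_forward: only labelled inputs are extracted.
    if isinstance(x, MiniLabelTensor):
        x = x.extract(list(input_variables))
    return x

labelled = MiniLabelTensor(torch.tensor([[1.0, 2.0]]), labels=["x", "y"])
plain = torch.tensor([[1.0, 2.0]])
print(wrapper(labelled))  # column 'x' only -> tensor([[1.]])
print(wrapper(plain))     # unlabelled input passes through untouched
```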
