
Commit 669d870

Fix bug in Collector with Graph data (#456)
* Fix bug in Collector with Graph data
* Add comments in DataModule class and bug fix in collate
1 parent 78e4562 commit 669d870

File tree

6 files changed: +254 -66 lines changed


pina/collector.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,7 @@
+"""
+# TODO
+"""
+from .graph import Graph
 from .utils import check_consistency


@@ -52,6 +56,8 @@ def store_fixed_data(self):
             # get data
             keys = condition.__slots__
             values = [getattr(condition, name) for name in keys]
+            values = [value.data if isinstance(
+                value, Graph) else value for value in values]
             self.data_collections[condition_name] = dict(zip(keys, values))
             # condition now is ready
             self._is_conditions_ready[condition_name] = True
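For context, the two added lines unwrap `Graph` containers before the condition data is stored, so downstream datasets receive the underlying list of PyG `Data` objects rather than the wrapper. A minimal sketch of that transformation, with a hypothetical `DummyGraph` standing in for `pina.graph.Graph` (only the `.data` attribute the fix relies on is mimicked):

```python
# DummyGraph mimics the one attribute of pina.graph.Graph the fix uses:
# .data, the underlying list of torch_geometric Data objects.
class DummyGraph:
    def __init__(self, data):
        self.data = data

values = [1.0, DummyGraph(data=["Data(x=...)", "Data(x=...)"])]
# The added list comprehension: Graph entries are replaced by their .data.
values = [v.data if isinstance(v, DummyGraph) else v for v in values]
assert values == [1.0, ["Data(x=...)", "Data(x=...)"]]
```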

pina/data/data_module.py

Lines changed: 53 additions & 31 deletions
@@ -2,11 +2,11 @@
 import warnings
 from lightning.pytorch import LightningDataModule
 import torch
-from ..label_tensor import LabelTensor
-from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
-    RandomSampler
+from torch_geometric.data import Data
+from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
-from .dataset import PinaDatasetFactory
+from ..label_tensor import LabelTensor
+from .dataset import PinaDatasetFactory, PinaTensorDataset
 from ..collector import Collector


@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
         self.dataset = dataset
+        if isinstance(self.dataset, PinaTensorDataset):
+            self._collate = self._collate_tensor_dataset
+        else:
+            self._collate = self._collate_graph_dataset

     def _collate_custom_dataloader(self, batch):
         return self.dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
         if isinstance(batch, dict):
             return batch
         conditions_names = batch[0].keys()
-
         # Condition names
         for condition_name in conditions_names:
             single_cond_dict = {}
@@ -82,16 +85,28 @@ def _collate_standard_dataloader(self, batch):
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
                         self.max_conditions_lengths[condition_name]))]
-                if isinstance(data_list[0], LabelTensor):
-                    single_cond_dict[arg] = LabelTensor.stack(data_list)
-                elif isinstance(data_list[0], torch.Tensor):
-                    single_cond_dict[arg] = torch.stack(data_list)
-                else:
-                    raise NotImplementedError(
-                        f"Data type {type(data_list[0])} not supported")
+                single_cond_dict[arg] = self._collate(data_list)
+
             batch_dict[condition_name] = single_cond_dict
         return batch_dict

+    @staticmethod
+    def _collate_tensor_dataset(data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.stack(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.stack(data_list)
+        raise RuntimeError("Data must be Tensors or LabelTensors")
+
+    def _collate_graph_dataset(self, data_list):
+        if isinstance(data_list[0], LabelTensor):
+            return LabelTensor.cat(data_list)
+        if isinstance(data_list[0], torch.Tensor):
+            return torch.cat(data_list)
+        if isinstance(data_list[0], Data):
+            return self.dataset.create_graph_batch(data_list)
+        raise RuntimeError("Data must be Tensors, LabelTensors or PyG Data")
+
     def __call__(self, batch):
         return self.callable_function(batch)

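The dispatch chosen in `__init__` above means the two collate paths merge samples differently: tensor datasets are stacked along a new batch dimension, while graph node features are concatenated along the node dimension, since graphs in one batch may have different node counts. A small self-contained sketch of the two behaviors (shapes are illustrative):

```python
import torch

# Tensor dataset: four samples of shape (3,) stack into a (4, 3) batch,
# as in _collate_tensor_dataset.
samples = [torch.rand(3) for _ in range(4)]
assert torch.stack(samples).shape == (4, 3)

# Graph dataset: node features of a 5-node and a 7-node graph are
# concatenated along dim 0 into a (12, 3) tensor, as in
# _collate_graph_dataset.
node_feats = [torch.rand(5, 3), torch.rand(7, 3)]
assert torch.cat(node_feats).shape == (12, 3)
```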
@@ -125,7 +140,7 @@ def __init__(self,
                  batch_size=None,
                  shuffle=True,
                  repeat=False,
-                 automatic_batching=False,
+                 automatic_batching=None,
                  num_workers=0,
                  pin_memory=False,
                  ):
@@ -158,15 +173,35 @@ def __init__(self,
         logging.debug('Start initialization of Pina DataModule')
         logging.info('Start initialization of Pina DataModule')
         super().__init__()
-        self.automatic_batching = automatic_batching
+
+        # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.repeat = repeat
+        self.automatic_batching = automatic_batching
+        if batch_size is None and num_workers != 0:
+            warnings.warn(
+                "Setting num_workers when batch_size is None has no effect on "
+                "the DataLoading process.")
+            self.num_workers = 0
+        else:
+            self.num_workers = num_workers
+        if batch_size is None and pin_memory:
+            warnings.warn("Setting pin_memory to True has no effect when "
+                          "batch_size is None.")
+            self.pin_memory = False
+        else:
+            self.pin_memory = pin_memory
+
+        # Collect data
+        collector = Collector(problem)
+        collector.store_fixed_data()
+        collector.store_sample_domains()

         # Check if the splits are correct
         self._check_slit_sizes(train_size, test_size, val_size, predict_size)

-        # Begin Data splitting
+        # Split input data into subsets
         splits_dict = {}
         if train_size > 0:
             splits_dict['train'] = train_size
@@ -188,19 +223,6 @@ def __init__(self,
             self.predict_dataset = None
         else:
             self.predict_dataloader = super().predict_dataloader
-
-        collector = Collector(problem)
-        collector.store_fixed_data()
-        collector.store_sample_domains()
-        if batch_size is None and num_workers != 0:
-            warnings.warn(
-                "Setting num_workers when batch_size is None has no effect on "
-                "the DataLoading process.")
-        if batch_size is None and pin_memory:
-            warnings.warn("Setting pin_memory to True has no effect when "
-                          "batch_size is None.")
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
         self.collector_splits = self._create_splits(collector, splits_dict)
         self.transfer_batch_to_device = self._transfer_batch_to_device

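The relocated guards now also reset the ineffective options instead of only warning. The helper below restates that logic outside the class as a sketch; `resolve_loader_opts` is an illustrative name, not PINA API:

```python
import warnings

def resolve_loader_opts(batch_size, num_workers, pin_memory):
    # With batch_size=None the whole dataset is served as one batch, so
    # extra workers and pinned memory cannot help; warn and disable them.
    if batch_size is None and num_workers != 0:
        warnings.warn("Setting num_workers when batch_size is None has "
                      "no effect on the DataLoading process.")
        num_workers = 0
    if batch_size is None and pin_memory:
        warnings.warn("Setting pin_memory to True has no effect when "
                      "batch_size is None.")
        pin_memory = False
    return num_workers, pin_memory

assert resolve_loader_opts(None, 4, True) == (0, False)
```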
@@ -316,10 +338,10 @@ def _create_dataloader(self, split, dataset):
         if self.batch_size is not None:
             sampler = PinaSampler(dataset, shuffle)
             if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths(split))
-
+                collate = Collator(self.find_max_conditions_lengths(split),
+                                   dataset=dataset)
             else:
-                collate = Collator(None, dataset)
+                collate = Collator(None, dataset=dataset)
             return DataLoader(dataset, self.batch_size,
                               collate_fn=collate, sampler=sampler,
                               num_workers=self.num_workers)

pina/data/dataset.py

Lines changed: 91 additions & 13 deletions
@@ -1,10 +1,12 @@
 """
 This module provides basic data management functionalities
 """
+import functools
 import torch
 from torch.utils.data import Dataset
 from abc import abstractmethod
-from torch_geometric.data import Batch
+from torch_geometric.data import Batch, Data
+from pina import LabelTensor


 class PinaDatasetFactory:
@@ -62,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths,
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy

     def _getitem_int(self, idx):
         return {
@@ -82,7 +84,7 @@ def fetch_from_idx_list(self, idx):
         return to_return_dict

     @staticmethod
-    def _getitem_list(idx):
+    def _getitem_dummy(idx):
         return idx

     def get_all_data(self):
@@ -102,15 +104,56 @@ def input_points(self):
         }


+class PinaBatch(Batch):
+    """
+    Add an extract method to the torch_geometric Batch object.
+    """
+    def __init__(self):
+
+        super().__init__(self)
+
+    def extract(self, labels):
+        """
+        Perform extraction of labels on node features (x).
+
+        :param labels: Labels to extract
+        :type labels: list[str] | tuple[str] | str
+        :return: Batch object with extraction performed on x
+        :rtype: PinaBatch
+        """
+        self.x = self.x.extract(labels)
+        return self
+
+
 class PinaGraphDataset(PinaDataset):

     def __init__(self, conditions_dict, max_conditions_lengths,
                  automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        self.in_labels = {}
+        self.out_labels = None
         if automatic_batching:
             self._getitem_func = self._getitem_int
         else:
-            self._getitem_func = self._getitem_list
+            self._getitem_func = self._getitem_dummy
+
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['input_points'][0]
+        for name, attr in ex_data.items():
+            if isinstance(attr, LabelTensor):
+                self.in_labels[name] = attr.stored_labels
+        ex_data = conditions_dict[list(conditions_dict.keys())[
+            0]]['output_points'][0]
+        if isinstance(ex_data, LabelTensor):
+            self.out_labels = ex_data.labels
+
+        self._create_graph_batch_from_list = self._labelise_batch(
+            self._base_create_graph_batch_from_list) if self.in_labels \
+            else self._base_create_graph_batch_from_list
+
+        self._create_output_batch = self._labelise_tensor(
+            self._base_create_output_batch) if self.out_labels is not None \
+            else self._base_create_output_batch

     def fetch_from_idx_list(self, idx):
         to_return_dict = {}
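`PinaBatch.extract` defers to `LabelTensor.extract` on the node features, so named columns can be selected after batching. A hedged sketch of the underlying semantics (labels and shapes are illustrative):

```python
import torch
from pina import LabelTensor

# What PinaBatch.extract does to batch.x: keep only the requested
# labelled columns of the node features.
x = LabelTensor(torch.rand(6, 3), ['u', 'v', 'p'])
extracted = x.extract(['u', 'p'])
assert extracted.shape == (6, 2)
```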
@@ -119,17 +162,24 @@ def fetch_from_idx_list(self, idx):
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
                 cond_idx = [idx % condition_len for idx in cond_idx]
-            to_return_dict[condition] = {k: Batch.from_data_list([
-                v[i] for i in cond_idx])
-                if isinstance(v, list)
-                else v[
-                cond_idx].reshape(
-                -1, *v[cond_idx].shape[2:])
-                for k, v in data.items()
-            }
+            to_return_dict[condition] = {
+                k: self._create_graph_batch_from_list([v[i] for i in idx])
+                if isinstance(v, list)
+                else self._create_output_batch(v[idx])
+                for k, v in data.items()
+            }
+
         return to_return_dict

-    def _getitem_list(self, idx):
+    def _base_create_graph_batch_from_list(self, data):
+        batch = PinaBatch.from_data_list(data)
+        return batch
+
+    def _base_create_output_batch(self, data):
+        out = data.reshape(-1, *data.shape[2:])
+        return out
+
+    def _getitem_dummy(self, idx):
         return idx

     def _getitem_int(self, idx):
@@ -144,3 +194,31 @@ def get_all_data(self):

     def __getitem__(self, idx):
         return self._getitem_func(idx)
+
+    def _labelise_batch(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            batch = func(*args, **kwargs)
+            for k, v in self.in_labels.items():
+                tmp = batch[k]
+                tmp.labels = v
+                batch[k] = tmp
+            return batch
+        return wrapper
+
+    def _labelise_tensor(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            out = func(*args, **kwargs)
+            if isinstance(out, LabelTensor):
+                out.labels = self.out_labels
+            return out
+        return wrapper
+
+    def create_graph_batch(self, data):
+        """
+        # TODO
+        """
+        if isinstance(data[0], Data):
+            return self._create_graph_batch_from_list(data)
+        return self._create_output_batch(data)
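The `_labelise_*` helpers exist because batching and reshaping can drop the label metadata carried by `LabelTensor`; wrapping the builders re-attaches it once per call. The same wrap-and-restore pattern in isolation (all names here are illustrative, not PINA API):

```python
import functools

class Box:
    """Stand-in for an object whose labels are lost during building."""
    def __init__(self, data):
        self.data = data
        self.labels = None

def labelise(labels):
    # Decorator restoring metadata the wrapped builder discards,
    # mirroring _labelise_batch/_labelise_tensor above.
    def deco(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            out = func(*args, **kwargs)
            out.labels = labels
            return out
        return wrapper
    return deco

@labelise(['u', 'v'])
def build(values):  # analogous to _base_create_output_batch
    return Box(values)

assert build([1, 2]).labels == ['u', 'v']
```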

pina/graph.py

Lines changed: 11 additions & 7 deletions
@@ -108,16 +108,14 @@ def __init__(
             x)

         # Perform the graph construction
-        self._build_graph_list(x, pos, edge_index, edge_attr, additional_params)
+        self._build_graph_list(
+            x, pos, edge_index, edge_attr, additional_params)

     def _build_graph_list(self, x, pos, edge_index, edge_attr,
                           additional_params):
         for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)):
-            if isinstance(x_, LabelTensor):
-                x_ = x_.tensor
             add_params_local = {k: v[i] for k, v in additional_params.items()}
             if edge_attr is not None:
-
                 self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_,
                                       edge_attr=edge_attr[i],
                                       **add_params_local))
@@ -127,7 +125,8 @@ def _build_graph_list(self, x, pos, edge_index, edge_attr,

     @staticmethod
     def _build_edge_attr(x, pos, edge_index):
-        distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
+        distance = torch.abs(pos[edge_index[0]] -
+                             pos[edge_index[1]]).as_subclass(torch.Tensor)
         return distance

     @staticmethod
@@ -165,7 +164,8 @@ def _check_input_consistency(x, pos, edge_index=None):
         # If edge_index is a 3D tensor, we split it into a list of 2D tensors
         if edge_index is not None:
             if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3:
-                edge_index = [edge_index[i] for i in range(edge_index.shape[0])]
+                edge_index = [edge_index[i]
+                              for i in range(edge_index.shape[0])]
             elif not (isinstance(edge_index, list) and all(
                     t.ndim == 2 for t in edge_index)) and not (
                     isinstance(edge_index,
@@ -219,7 +219,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len,
         if isinstance(edge_attr, list):
             if len(edge_attr) != data_len:
                 raise TypeError("edge_attr must have the same length as x "
-                                    "and pos.")
+                                "and pos.")
             return [edge_attr] * data_len

         if build_edge_attr:
@@ -258,6 +258,8 @@ def _radius_graph(points, r):
         """
         dist = torch.cdist(points, points, p=2)
         edge_index = torch.nonzero(dist <= r, as_tuple=False).t()
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index


@@ -293,4 +295,6 @@ def _knn_graph(points, k):
         row = torch.arange(points.size(0)).repeat_interleave(k)
         col = knn_indices.flatten()
         edge_index = torch.stack([row, col], dim=0)
+        if isinstance(edge_index, LabelTensor):
+            edge_index = edge_index.tensor
         return edge_index
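Both helpers can end up with a `LabelTensor` when `points` carries labels, since indexing and stacking may propagate the subclass; the added downcast keeps `edge_index` a plain `torch.Tensor`, which is what PyG expects. A short sketch of the radius-graph rule itself, on plain tensors:

```python
import torch

# Same edge rule as _radius_graph: connect all pairs within radius r
# (self-loops included, because dist(i, i) = 0 <= r).
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]])
dist = torch.cdist(points, points, p=2)
edge_index = torch.nonzero(dist <= 1.5, as_tuple=False).t()
# edge_index -> tensor([[0, 0, 1, 1, 2],
#                       [0, 1, 0, 1, 2]])
```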
