Skip to content

Commit f32af3c

Browse files
committed
Add comments to the DataModule class and fix a bug in collate
1 parent 1d0ea1c commit f32af3c

File tree

4 files changed

+169
-75
lines changed

4 files changed

+169
-75
lines changed

pina/data/data_module.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import warnings
33
from lightning.pytorch import LightningDataModule
44
import torch
5-
from torch_geometric.data import Data, Batch
5+
from torch_geometric.data import Data
66
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
77
from torch.utils.data.distributed import DistributedSampler
88
from ..label_tensor import LabelTensor
9-
from .dataset import PinaDatasetFactory
9+
from .dataset import PinaDatasetFactory, PinaTensorDataset
1010
from ..collector import Collector
1111

1212

@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
6161
max_conditions_lengths is None else (
6262
self._collate_standard_dataloader)
6363
self.dataset = dataset
64+
if isinstance(self.dataset, PinaTensorDataset):
65+
self._collate = self._collate_tensor_dataset
66+
else:
67+
self._collate = self._collate_graph_dataset
6468

6569
def _collate_custom_dataloader(self, batch):
6670
return self.dataset.fetch_from_idx_list(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
7377
if isinstance(batch, dict):
7478
return batch
7579
conditions_names = batch[0].keys()
76-
7780
# Condition names
7881
for condition_name in conditions_names:
7982
single_cond_dict = {}
@@ -82,15 +85,28 @@ def _collate_standard_dataloader(self, batch):
8285
data_list = [batch[idx][condition_name][arg] for idx in range(
8386
min(len(batch),
8487
self.max_conditions_lengths[condition_name]))]
85-
if isinstance(data_list[0], LabelTensor):
86-
single_cond_dict[arg] = LabelTensor.stack(data_list)
87-
elif isinstance(data_list[0], torch.Tensor):
88-
single_cond_dict[arg] = torch.stack(data_list)
89-
elif isinstance(data_list[0], Data):
90-
single_cond_dict[arg] = Batch.from_data_list(data_list)
88+
single_cond_dict[arg] = self._collate(data_list)
89+
9190
batch_dict[condition_name] = single_cond_dict
9291
return batch_dict
9392

93+
@staticmethod
94+
def _collate_tensor_dataset(data_list):
95+
if isinstance(data_list[0], LabelTensor):
96+
return LabelTensor.stack(data_list)
97+
if isinstance(data_list[0], torch.Tensor):
98+
return torch.stack(data_list)
99+
raise RuntimeError("Data must be Tensors or LabelTensor ")
100+
101+
def _collate_graph_dataset(self, data_list):
102+
if isinstance(data_list[0], LabelTensor):
103+
return LabelTensor.cat(data_list)
104+
if isinstance(data_list[0], torch.Tensor):
105+
return torch.cat(data_list)
106+
if isinstance(data_list[0], Data):
107+
return self.dataset.create_graph_batch(data_list)
108+
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
109+
94110
def __call__(self, batch):
95111
return self.callable_function(batch)
96112

@@ -157,14 +173,36 @@ def __init__(self,
157173
logging.debug('Start initialization of Pina DataModule')
158174
logging.info('Start initialization of Pina DataModule')
159175
super().__init__()
176+
177+
# Store fixed attributes
160178
self.batch_size = batch_size
161179
self.shuffle = shuffle
162180
self.repeat = repeat
181+
self.automatic_batching = automatic_batching if automatic_batching \
182+
is not None else False
183+
if batch_size is None and num_workers != 0:
184+
warnings.warn(
185+
"Setting num_workers when batch_size is None has no effect on "
186+
"the DataLoading process.")
187+
self.num_workers = 0
188+
else:
189+
self.num_workers = num_workers
190+
if batch_size is None and pin_memory:
191+
warnings.warn("Setting pin_memory to True has no effect when "
192+
"batch_size is None.")
193+
self.pin_memory = False
194+
else:
195+
self.pin_memory = pin_memory
196+
197+
# Collect data
198+
collector = Collector(problem)
199+
collector.store_fixed_data()
200+
collector.store_sample_domains()
163201

164202
# Check if the splits are correct
165203
self._check_slit_sizes(train_size, test_size, val_size, predict_size)
166204

167-
# Begin Data splitting
205+
# Split input data into subsets
168206
splits_dict = {}
169207
if train_size > 0:
170208
splits_dict['train'] = train_size
@@ -186,23 +224,6 @@ def __init__(self,
186224
self.predict_dataset = None
187225
else:
188226
self.predict_dataloader = super().predict_dataloader
189-
190-
collector = Collector(problem)
191-
collector.store_fixed_data()
192-
collector.store_sample_domains()
193-
194-
self.automatic_batching = self._set_automatic_batching_option(
195-
collector, automatic_batching)
196-
197-
if batch_size is None and num_workers != 0:
198-
warnings.warn(
199-
"Setting num_workers when batch_size is None has no effect on "
200-
"the DataLoading process.")
201-
if batch_size is None and pin_memory:
202-
warnings.warn("Setting pin_memory to True has no effect when "
203-
"batch_size is None.")
204-
self.num_workers = num_workers
205-
self.pin_memory = pin_memory
206227
self.collector_splits = self._create_splits(collector, splits_dict)
207228
self.transfer_batch_to_device = self._transfer_batch_to_device
208229

@@ -318,10 +339,10 @@ def _create_dataloader(self, split, dataset):
318339
if self.batch_size is not None:
319340
sampler = PinaSampler(dataset, shuffle)
320341
if self.automatic_batching:
321-
collate = Collator(self.find_max_conditions_lengths(split))
322-
342+
collate = Collator(self.find_max_conditions_lengths(split),
343+
dataset=dataset)
323344
else:
324-
collate = Collator(None, dataset)
345+
collate = Collator(None, dataset=dataset)
325346
return DataLoader(dataset, self.batch_size,
326347
collate_fn=collate, sampler=sampler,
327348
num_workers=self.num_workers)
@@ -395,27 +416,6 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size):
395416
if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
396417
raise ValueError("The sum of the splits must be 1")
397418

398-
@staticmethod
399-
def _set_automatic_batching_option(collector, automatic_batching):
400-
"""
401-
Determines whether automatic batching should be enabled.
402-
403-
If all 'input_points' in the collector's data collections are
404-
tensors (torch.Tensor or LabelTensor), it respects the provided
405-
`automatic_batching` value; otherwise, mainly in the Graph scenario,
406-
it forces automatic batching on.
407-
408-
:param Collector collector: Collector object with contains all data
409-
retrieved from input conditions
410-
:param bool automatic_batching : If the user wants to enable automatic
411-
batching or not
412-
"""
413-
if all(isinstance(v['input_points'], (torch.Tensor, LabelTensor))
414-
for v in collector.data_collections.values()):
415-
return automatic_batching if automatic_batching is not None \
416-
else False
417-
return True
418-
419419
@property
420420
def input_points(self):
421421
"""

pina/data/dataset.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from torch.utils.data import Dataset
77
from abc import abstractmethod
8-
from torch_geometric.data import Batch
8+
from torch_geometric.data import Batch, Data
99
from pina import LabelTensor
1010

1111

@@ -64,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths,
6464
if automatic_batching:
6565
self._getitem_func = self._getitem_int
6666
else:
67-
self._getitem_func = self._getitem_list
67+
self._getitem_func = self._getitem_dummy
6868

6969
def _getitem_int(self, idx):
7070
return {
@@ -84,7 +84,7 @@ def fetch_from_idx_list(self, idx):
8484
return to_return_dict
8585

8686
@staticmethod
87-
def _getitem_list(idx):
87+
def _getitem_dummy(idx):
8888
return idx
8989

9090
def get_all_data(self):
@@ -111,6 +111,11 @@ def __init__(self, conditions_dict, max_conditions_lengths,
111111
super().__init__(conditions_dict, max_conditions_lengths)
112112
self.in_labels = {}
113113
self.out_labels = None
114+
if automatic_batching:
115+
self._getitem_func = self._getitem_int
116+
else:
117+
self._getitem_func = self._getitem_dummy
118+
114119
ex_data = conditions_dict[list(conditions_dict.keys())[
115120
0]]['input_points'][0]
116121
for name, attr in ex_data.items():
@@ -137,22 +142,25 @@ def fetch_from_idx_list(self, idx):
137142
if self.length > condition_len:
138143
cond_idx = [idx % condition_len for idx in cond_idx]
139144
to_return_dict[condition] = {
140-
k: self._create_graph_batch_from_list(v, cond_idx)
145+
k: self._create_graph_batch_from_list([v[i] for i in idx])
141146
if isinstance(v, list)
142-
else self._create_output_batch(v, cond_idx)
147+
else self._create_output_batch(v[idx])
143148
for k, v in data.items()
144149
}
145150

146151
return to_return_dict
147152

148-
def _base_create_graph_batch_from_list(self, data, idx):
149-
batch = Batch.from_data_list([data[i] for i in idx])
153+
def _base_create_graph_batch_from_list(self, data):
154+
batch = Batch.from_data_list(data)
150155
return batch
151156

152-
def _base_create_output_batch(self, data, idx):
153-
out = data[idx].reshape(-1, *data[idx].shape[2:])
157+
def _base_create_output_batch(self, data):
158+
out = data.reshape(-1, *data.shape[2:])
154159
return out
155160

161+
def _getitem_dummy(self, idx):
162+
return idx
163+
156164
def _getitem_int(self, idx):
157165
return {
158166
k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
@@ -164,8 +172,7 @@ def get_all_data(self):
164172
return self.fetch_from_idx_list(index)
165173

166174
def __getitem__(self, idx):
167-
return self._getitem_int(idx) if isinstance(idx, int) else \
168-
self.fetch_from_idx_list(idx=idx)
175+
return self._getitem_func(idx)
169176

170177
def _labelise_batch(self, func):
171178
@functools.wraps(func)
@@ -186,3 +193,11 @@ def wrapper(*args, **kwargs):
186193
out.labels = self.out_labels
187194
return out
188195
return wrapper
196+
197+
def create_graph_batch(self, data):
198+
"""
199+
# TODO
200+
"""
201+
if isinstance(data[0], Data):
202+
return self._create_graph_batch_from_list(data)
203+
return self._create_output_batch(data)

pina/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def _build_graph_list(self, x, pos, edge_index, edge_attr,
125125

126126
@staticmethod
127127
def _build_edge_attr(x, pos, edge_index):
128-
distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]])
128+
distance = torch.abs(pos[edge_index[0]] -
129+
pos[edge_index[1]]).as_subclass(torch.Tensor)
129130
return distance
130131

131132
@staticmethod

0 commit comments

Comments (0)