22import warnings
33from lightning .pytorch import LightningDataModule
44import torch
5- from torch_geometric .data import Data , Batch
5+ from torch_geometric .data import Data
66from torch .utils .data import DataLoader , SequentialSampler , RandomSampler
77from torch .utils .data .distributed import DistributedSampler
88from ..label_tensor import LabelTensor
9- from .dataset import PinaDatasetFactory
9+ from .dataset import PinaDatasetFactory , PinaTensorDataset
1010from ..collector import Collector
1111
1212
@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
6161 max_conditions_lengths is None else (
6262 self ._collate_standard_dataloader )
6363 self .dataset = dataset
64+ if isinstance (self .dataset , PinaTensorDataset ):
65+ self ._collate = self ._collate_tensor_dataset
66+ else :
67+ self ._collate = self ._collate_graph_dataset
6468
6569 def _collate_custom_dataloader (self , batch ):
6670 return self .dataset .fetch_from_idx_list (batch )
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
7377 if isinstance (batch , dict ):
7478 return batch
7579 conditions_names = batch [0 ].keys ()
76-
7780 # Condition names
7881 for condition_name in conditions_names :
7982 single_cond_dict = {}
@@ -82,15 +85,28 @@ def _collate_standard_dataloader(self, batch):
8285 data_list = [batch [idx ][condition_name ][arg ] for idx in range (
8386 min (len (batch ),
8487 self .max_conditions_lengths [condition_name ]))]
85- if isinstance (data_list [0 ], LabelTensor ):
86- single_cond_dict [arg ] = LabelTensor .stack (data_list )
87- elif isinstance (data_list [0 ], torch .Tensor ):
88- single_cond_dict [arg ] = torch .stack (data_list )
89- elif isinstance (data_list [0 ], Data ):
90- single_cond_dict [arg ] = Batch .from_data_list (data_list )
88+ single_cond_dict [arg ] = self ._collate (data_list )
89+
9190 batch_dict [condition_name ] = single_cond_dict
9291 return batch_dict
9392
93+ @staticmethod
94+ def _collate_tensor_dataset (data_list ):
95+ if isinstance (data_list [0 ], LabelTensor ):
96+ return LabelTensor .stack (data_list )
97+ if isinstance (data_list [0 ], torch .Tensor ):
98+ return torch .stack (data_list )
99+ raise RuntimeError ("Data must be Tensors or LabelTensor " )
100+
101+ def _collate_graph_dataset (self , data_list ):
102+ if isinstance (data_list [0 ], LabelTensor ):
103+ return LabelTensor .cat (data_list )
104+ if isinstance (data_list [0 ], torch .Tensor ):
105+ return torch .cat (data_list )
106+ if isinstance (data_list [0 ], Data ):
107+ return self .dataset .create_graph_batch (data_list )
108+ raise RuntimeError ("Data must be Tensors or LabelTensor or pyG Data" )
109+
94110 def __call__ (self , batch ):
95111 return self .callable_function (batch )
96112
@@ -157,14 +173,35 @@ def __init__(self,
157173 logging .debug ('Start initialization of Pina DataModule' )
158174 logging .info ('Start initialization of Pina DataModule' )
159175 super ().__init__ ()
176+
177+ # Store fixed attributes
160178 self .batch_size = batch_size
161179 self .shuffle = shuffle
162180 self .repeat = repeat
181+ self .automatic_batching = automatic_batching
182+ if batch_size is None and num_workers != 0 :
183+ warnings .warn (
184+ "Setting num_workers when batch_size is None has no effect on "
185+ "the DataLoading process." )
186+ self .num_workers = 0
187+ else :
188+ self .num_workers = num_workers
189+ if batch_size is None and pin_memory :
190+ warnings .warn ("Setting pin_memory to True has no effect when "
191+ "batch_size is None." )
192+ self .pin_memory = False
193+ else :
194+ self .pin_memory = pin_memory
195+
196+ # Collect data
197+ collector = Collector (problem )
198+ collector .store_fixed_data ()
199+ collector .store_sample_domains ()
163200
164201 # Check if the splits are correct
165202 self ._check_slit_sizes (train_size , test_size , val_size , predict_size )
166203
167- # Begin Data splitting
204+ # Split input data into subsets
168205 splits_dict = {}
169206 if train_size > 0 :
170207 splits_dict ['train' ] = train_size
@@ -186,23 +223,6 @@ def __init__(self,
186223 self .predict_dataset = None
187224 else :
188225 self .predict_dataloader = super ().predict_dataloader
189-
190- collector = Collector (problem )
191- collector .store_fixed_data ()
192- collector .store_sample_domains ()
193-
194- self .automatic_batching = self ._set_automatic_batching_option (
195- collector , automatic_batching )
196-
197- if batch_size is None and num_workers != 0 :
198- warnings .warn (
199- "Setting num_workers when batch_size is None has no effect on "
200- "the DataLoading process." )
201- if batch_size is None and pin_memory :
202- warnings .warn ("Setting pin_memory to True has no effect when "
203- "batch_size is None." )
204- self .num_workers = num_workers
205- self .pin_memory = pin_memory
206226 self .collector_splits = self ._create_splits (collector , splits_dict )
207227 self .transfer_batch_to_device = self ._transfer_batch_to_device
208228
@@ -318,10 +338,10 @@ def _create_dataloader(self, split, dataset):
318338 if self .batch_size is not None :
319339 sampler = PinaSampler (dataset , shuffle )
320340 if self .automatic_batching :
321- collate = Collator (self .find_max_conditions_lengths (split ))
322-
341+ collate = Collator (self .find_max_conditions_lengths (split ),
342+ dataset = dataset )
323343 else :
324- collate = Collator (None , dataset )
344+ collate = Collator (None , dataset = dataset )
325345 return DataLoader (dataset , self .batch_size ,
326346 collate_fn = collate , sampler = sampler ,
327347 num_workers = self .num_workers )
@@ -395,27 +415,6 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size):
395415 if abs (train_size + test_size + val_size + predict_size - 1 ) > 1e-6 :
396416 raise ValueError ("The sum of the splits must be 1" )
397417
398- @staticmethod
399- def _set_automatic_batching_option (collector , automatic_batching ):
400- """
401- Determines whether automatic batching should be enabled.
402-
403- If all 'input_points' in the collector's data collections are
404- tensors (torch.Tensor or LabelTensor), it respects the provided
405- `automatic_batching` value; otherwise, mainly in the Graph scenario,
406- it forces automatic batching on.
407-
408- :param Collector collector: Collector object with contains all data
409- retrieved from input conditions
410- :param bool automatic_batching : If the user wants to enable automatic
411- batching or not
412- """
413- if all (isinstance (v ['input_points' ], (torch .Tensor , LabelTensor ))
414- for v in collector .data_collections .values ()):
415- return automatic_batching if automatic_batching is not None \
416- else False
417- return True
418-
419418 @property
420419 def input_points (self ):
421420 """
0 commit comments