22import warnings
33from lightning .pytorch import LightningDataModule
44import torch
5- from ..label_tensor import LabelTensor
6- from torch .utils .data import DataLoader , BatchSampler , SequentialSampler , \
7- RandomSampler
5+ from torch_geometric .data import Data
6+ from torch .utils .data import DataLoader , SequentialSampler , RandomSampler
87from torch .utils .data .distributed import DistributedSampler
9- from .dataset import PinaDatasetFactory
8+ from ..label_tensor import LabelTensor
9+ from .dataset import PinaDatasetFactory , PinaTensorDataset
1010from ..collector import Collector
1111
1212
@@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None):
6161 max_conditions_lengths is None else (
6262 self ._collate_standard_dataloader )
6363 self .dataset = dataset
64+ if isinstance (self .dataset , PinaTensorDataset ):
65+ self ._collate = self ._collate_tensor_dataset
66+ else :
67+ self ._collate = self ._collate_graph_dataset
6468
def _collate_custom_dataloader(self, batch):
    """Collate by delegating to the dataset's own index-based fetcher.

    :param batch: List of sample indices produced by the batch sampler.
    :return: Whatever ``self.dataset.fetch_from_idx_list`` builds for
        those indices (the dataset owns the batching logic here).
    """
    fetch = self.dataset.fetch_from_idx_list
    return fetch(batch)
@@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch):
7377 if isinstance (batch , dict ):
7478 return batch
7579 conditions_names = batch [0 ].keys ()
76-
7780 # Condition names
7881 for condition_name in conditions_names :
7982 single_cond_dict = {}
@@ -82,16 +85,28 @@ def _collate_standard_dataloader(self, batch):
8285 data_list = [batch [idx ][condition_name ][arg ] for idx in range (
8386 min (len (batch ),
8487 self .max_conditions_lengths [condition_name ]))]
85- if isinstance (data_list [0 ], LabelTensor ):
86- single_cond_dict [arg ] = LabelTensor .stack (data_list )
87- elif isinstance (data_list [0 ], torch .Tensor ):
88- single_cond_dict [arg ] = torch .stack (data_list )
89- else :
90- raise NotImplementedError (
91- f"Data type { type (data_list [0 ])} not supported" )
88+ single_cond_dict [arg ] = self ._collate (data_list )
89+
9290 batch_dict [condition_name ] = single_cond_dict
9391 return batch_dict
9492
@staticmethod
def _collate_tensor_dataset(data_list):
    """Stack a homogeneous list of tensors into a single batch tensor.

    :param data_list: Non-empty list whose elements are all
        ``LabelTensor`` or all ``torch.Tensor``.
    :return: The stacked tensor (``LabelTensor.stack`` or ``torch.stack``).
    :raises RuntimeError: If the elements are neither ``LabelTensor``
        nor ``torch.Tensor``.
    """
    # NOTE(review): LabelTensor is checked first — presumably it is a
    # torch.Tensor subclass, so the order of isinstance checks matters;
    # confirm against the LabelTensor definition.
    if isinstance(data_list[0], LabelTensor):
        return LabelTensor.stack(data_list)
    if isinstance(data_list[0], torch.Tensor):
        return torch.stack(data_list)
    # Fixed: original message had a stray trailing space ("LabelTensor ").
    raise RuntimeError("Data must be Tensors or LabelTensor")
100+
def _collate_graph_dataset(self, data_list):
    """Merge a per-condition list of items coming from a graph dataset.

    Tensors are concatenated along the first dimension; PyG ``Data``
    objects are merged into a graph batch by the dataset itself.

    :param data_list: Non-empty homogeneous list of ``LabelTensor``,
        ``torch.Tensor`` or torch_geometric ``Data`` items.
    :return: The concatenated tensor or the dataset-built graph batch.
    :raises RuntimeError: For any other element type.
    """
    head = data_list[0]
    # NOTE(review): LabelTensor is tested before torch.Tensor —
    # presumably a Tensor subclass; verify ordering is intentional.
    if isinstance(head, LabelTensor):
        merged = LabelTensor.cat(data_list)
    elif isinstance(head, torch.Tensor):
        merged = torch.cat(data_list)
    elif isinstance(head, Data):
        merged = self.dataset.create_graph_batch(data_list)
    else:
        raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
    return merged
109+
def __call__(self, batch):
    """Apply the collate strategy selected at construction time.

    :param batch: The raw batch handed over by the DataLoader.
    :return: The collated batch produced by ``self.callable_function``.
    """
    collate = self.callable_function
    return collate(batch)
97112
@@ -125,7 +140,7 @@ def __init__(self,
125140 batch_size = None ,
126141 shuffle = True ,
127142 repeat = False ,
128- automatic_batching = False ,
143+ automatic_batching = None ,
129144 num_workers = 0 ,
130145 pin_memory = False ,
131146 ):
@@ -158,15 +173,35 @@ def __init__(self,
158173 logging .debug ('Start initialization of Pina DataModule' )
159174 logging .info ('Start initialization of Pina DataModule' )
160175 super ().__init__ ()
161- self .automatic_batching = automatic_batching
176+
177+ # Store fixed attributes
162178 self .batch_size = batch_size
163179 self .shuffle = shuffle
164180 self .repeat = repeat
181+ self .automatic_batching = automatic_batching
182+ if batch_size is None and num_workers != 0 :
183+ warnings .warn (
184+ "Setting num_workers when batch_size is None has no effect on "
185+ "the DataLoading process." )
186+ self .num_workers = 0
187+ else :
188+ self .num_workers = num_workers
189+ if batch_size is None and pin_memory :
190+ warnings .warn ("Setting pin_memory to True has no effect when "
191+ "batch_size is None." )
192+ self .pin_memory = False
193+ else :
194+ self .pin_memory = pin_memory
195+
196+ # Collect data
197+ collector = Collector (problem )
198+ collector .store_fixed_data ()
199+ collector .store_sample_domains ()
165200
166201 # Check if the splits are correct
167202 self ._check_slit_sizes (train_size , test_size , val_size , predict_size )
168203
169- # Begin Data splitting
204+ # Split input data into subsets
170205 splits_dict = {}
171206 if train_size > 0 :
172207 splits_dict ['train' ] = train_size
@@ -188,19 +223,6 @@ def __init__(self,
188223 self .predict_dataset = None
189224 else :
190225 self .predict_dataloader = super ().predict_dataloader
191-
192- collector = Collector (problem )
193- collector .store_fixed_data ()
194- collector .store_sample_domains ()
195- if batch_size is None and num_workers != 0 :
196- warnings .warn (
197- "Setting num_workers when batch_size is None has no effect on "
198- "the DataLoading process." )
199- if batch_size is None and pin_memory :
200- warnings .warn ("Setting pin_memory to True has no effect when "
201- "batch_size is None." )
202- self .num_workers = num_workers
203- self .pin_memory = pin_memory
204226 self .collector_splits = self ._create_splits (collector , splits_dict )
205227 self .transfer_batch_to_device = self ._transfer_batch_to_device
206228
@@ -316,10 +338,10 @@ def _create_dataloader(self, split, dataset):
316338 if self .batch_size is not None :
317339 sampler = PinaSampler (dataset , shuffle )
318340 if self .automatic_batching :
319- collate = Collator (self .find_max_conditions_lengths (split ))
320-
341+ collate = Collator (self .find_max_conditions_lengths (split ),
342+ dataset = dataset )
321343 else :
322- collate = Collator (None , dataset )
344+ collate = Collator (None , dataset = dataset )
323345 return DataLoader (dataset , self .batch_size ,
324346 collate_fn = collate , sampler = sampler ,
325347 num_workers = self .num_workers )