import logging
from lightning.pytorch import LightningDataModule
- import math
import torch
from ..label_tensor import LabelTensor
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
    RandomSampler
from ..collector import Collector

class DummyDataloader:
-    def __init__(self, dataset, device):
-        self.dataset = dataset.get_all_data()
+    """
+    Dummy dataloader used when the batch size is None. It collects all
+    the data in self.dataset and returns it as a single batch when
+    called.
+    """
+
+    def __init__(self, dataset):
18+ """
19+ param dataset: The dataset object to be processed.
20+ :notes:
21+ - **Distributed Environment**:
22+ - Divides the dataset across processes using the
23+ rank and world size.
24+ - Fetches only the portion of data corresponding to
25+ the current process.
26+ - **Non-Distributed Environment**:
27+ - Fetches the entire dataset.
28+ """
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            rank = torch.distributed.get_rank()
+            world_size = torch.distributed.get_world_size()
+            if len(dataset) < world_size:
+                raise RuntimeError(
+                    "Dataset size is smaller than the world size."
+                    " Increase the size of the partition or use a single GPU")
+            idx, i = [], rank
+            while i < len(dataset):
+                idx.append(i)
+                i += world_size
+            self.dataset = dataset.fetch_from_idx_list(idx)
+        else:
+            self.dataset = dataset.get_all_data()
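
A minimal sketch of the round-robin partitioning performed above, assuming a 10-sample dataset and a world size of 4 (the sizes are illustrative, not from the commit): each rank keeps every world_size-th index, starting from its own rank.

# Illustrative only: mirrors the idx-building while-loop above.
dataset_len, world_size = 10, 4
partitions = [list(range(rank, dataset_len, world_size))
              for rank in range(world_size)]
assert partitions == [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]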

    def __iter__(self):
        return self
@@ -50,7 +79,7 @@ def _collate_standard_dataloader(self, batch):
        for arg in condition_args:
            data_list = [batch[idx][condition_name][arg] for idx in range(
                min(len(batch),
-                   self.max_conditions_lengths[condition_name]))]
+                    self.max_conditions_lengths[condition_name]))]
            if isinstance(data_list[0], LabelTensor):
                single_cond_dict[arg] = LabelTensor.stack(data_list)
            elif isinstance(data_list[0], torch.Tensor):
@@ -61,7 +90,6 @@ def _collate_standard_dataloader(self, batch):
            batch_dict[condition_name] = single_cond_dict
        return batch_dict

-
    def __call__(self, batch):
        return self.callable_function(batch)

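For context, a sketch of what the collation above produces for the plain-tensor branch: a list of per-sample condition dicts becomes one dict of stacked tensors. The condition name and shapes below are assumptions for illustration only.

# Illustrative only: four samples of one condition, stacked into a batch.
batch = [{'gamma': {'input_points': torch.rand(2)}} for _ in range(4)]
stacked = torch.stack([b['gamma']['input_points'] for b in batch])
assert stacked.shape == torch.Size([4, 2])
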
@@ -99,6 +127,7 @@ def __init__(self,
                 ):
        """
        Initialize the object, creating the datasets based on the input
        problem.

+        :param problem: The problem where the data are defined.
        :param train_size: number/percentage of elements in the train split
        :param test_size: number/percentage of elements in the test split
        :param val_size: number/percentage of elements in the validation split
@@ -112,6 +141,9 @@ def __init__(self,
        self.shuffle = shuffle
        self.repeat = repeat

+        # Check that the split sizes are valid
+        self._check_split_sizes(train_size, test_size, val_size, predict_size)
+
        # Begin data splitting
        splits_dict = {}
        if train_size > 0:
@@ -179,23 +211,28 @@ def _split_condition(condition_dict, splits_dict):
        len_condition = len(condition_dict['input_points'])

        lengths = [
-            int(math.floor(len_condition * length)) for length in
+            int(len_condition * length) for length in
            splits_dict.values()
        ]

        remainder = len_condition - sum(lengths)
        for i in range(remainder):
            lengths[i % len(lengths)] += 1
-        splits_dict = {k: v for k, v in zip(splits_dict.keys(), lengths)
+
+        splits_dict = {k: max(1, v) for k, v in zip(splits_dict.keys(), lengths)
                       }
        to_return_dict = {}
        offset = 0
+
        for stage, stage_len in splits_dict.items():
            to_return_dict[stage] = {k: v[offset:offset + stage_len]
                                     for k, v in condition_dict.items() if
                                     k != 'equation'
                                     # Equations are NEVER dataloaded
                                     }
+            if offset + stage_len > len_condition:
+                offset = len_condition - 1
+                continue
            offset += stage_len
        return to_return_dict

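A worked example of the length arithmetic above, assuming a condition with 11 points and splits of 0.7/0.2/0.1 (illustrative values, not from the commit): truncation gives [7, 2, 1] (sum 10), and the remainder loop hands the leftover point to the splits round-robin.

# Illustrative only: mirrors the truncation and remainder loop above.
lengths = [int(11 * frac) for frac in (0.7, 0.2, 0.1)]
for i in range(11 - sum(lengths)):
    lengths[i % len(lengths)] += 1
assert lengths == [8, 2, 1]
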
@@ -234,6 +271,26 @@ def _apply_shuffle(condition_dict, len_data):
                dataset_dict[key].update({condition_name: data})
        return dataset_dict

+
+    def _create_dataloader(self, split, dataset):
+        """
+        Create the dataloader for the given split. Falls back to a
+        DummyDataloader serving the whole split as a single batch when
+        batch_size is None.
+        """
+        shuffle = self.shuffle if split == 'train' else False
+        # Use custom batching (good if batch size is large)
+        if self.batch_size is not None:
+            sampler = PinaSampler(dataset, self.batch_size,
+                                  shuffle, self.automatic_batching)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths(split))
+            else:
+                collate = Collator(None, dataset)
+            return DataLoader(dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(dataset)
+        dataloader.dataset = self._transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader
+
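A note on the batch_size=None branch above: the whole split is moved to the device once, at dataloader creation, and transfer_batch_to_device is then swapped for a no-op so Lightning does not re-transfer the same pre-loaded tensors on every step. A minimal standalone sketch of that pattern (the class and names are hypothetical, not part of this module):

class EagerDeviceLoader:
    """Hypothetical: transfer once at construction, serve as-is."""

    def __init__(self, data, device):
        # One-time transfer; per-batch transfer hooks can then be no-ops.
        self.data = {k: v.to(device) for k, v in data.items()}

    def __iter__(self):
        yield self.data  # single pre-transferred batch
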
    def find_max_conditions_lengths(self, split):
        max_conditions_lengths = {}
        for k, v in self.collector_splits[split].items():
@@ -250,60 +307,28 @@ def val_dataloader(self):
250307 """
251308 Create the validation dataloader
252309 """
253- # Use custom batching (good if batch size is large)
254- if self .batch_size is not None :
255- sampler = PinaSampler (self .val_dataset , self .batch_size ,
256- self .shuffle , self .automatic_batching )
257- if self .automatic_batching :
258- collate = Collator (self .find_max_conditions_lengths ('val' ))
259- else :
260- collate = Collator (None , self .val_dataset )
261- return DataLoader (self .val_dataset , self .batch_size ,
262- collate_fn = collate , sampler = sampler )
263- dataloader = DummyDataloader (self .val_dataset ,
264- self .trainer .strategy .root_device )
265- dataloader .dataset = self ._transfer_batch_to_device (dataloader .dataset ,
266- self .trainer .strategy .root_device ,
267- 0 )
268- self .transfer_batch_to_device = self ._transfer_batch_to_device_dummy
269- return dataloader
310+ return self ._create_dataloader ('val' , self .val_dataset )
270311
    def train_dataloader(self):
        """
        Create the training dataloader.
        """
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(self.train_dataset, self.batch_size,
-                                  self.shuffle, self.automatic_batching)
-            if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths('train'))
-
-            else:
-                collate = Collator(None, self.train_dataset)
-            return DataLoader(self.train_dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
-        dataloader = DummyDataloader(self.train_dataset,
-                                     self.trainer.strategy.root_device)
-        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
-                                                            self.trainer.strategy.root_device,
-                                                            0)
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return self._create_dataloader('train', self.train_dataset)

    def test_dataloader(self):
        """
        Create the testing dataloader.
        """
-        raise NotImplementedError("Test dataloader not implemented")
+        return self._create_dataloader('test', self.test_dataset)

    def predict_dataloader(self):
        """
        Create the prediction dataloader.
        """
        raise NotImplementedError("Predict dataloader not implemented")

-    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
+    @staticmethod
+    def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
        return batch

    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
@@ -312,10 +337,34 @@ def _transfer_batch_to_device(self, batch, device, dataloader_idx):
        training loop and is used to transfer the batch to the device.
        """
        batch = [
-            (k, super(LightningDataModule, self).transfer_batch_to_device(v,
-                                                                          device,
-                                                                          dataloader_idx))
+            (k,
+             super(LightningDataModule, self).transfer_batch_to_device(
+                 v, device, dataloader_idx))
            for k, v in batch.items()
        ]

        return batch
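
A plain-torch sketch of what this hook does for one nested batch dict (the condition name and shape are assumptions for illustration, and 'cpu' stands in for the target device); note it returns a list of (name, data) pairs, mirroring the method above.

# Illustrative only: map the device transfer over each condition.
nested = {'gamma': {'input_points': torch.rand(2, 3)}}
moved = [(name, {k: v.to('cpu') for k, v in data.items()})
         for name, data in nested.items()]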
+
+    @staticmethod
+    def _check_split_sizes(train_size, test_size, val_size, predict_size):
+        """
+        Check that the split sizes are valid: non-negative and summing
+        to 1.
+        """
+        if train_size < 0 or test_size < 0 or val_size < 0 or predict_size < 0:
+            raise ValueError("The splits must be non-negative")
+        if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
+            raise ValueError("The sum of the splits must be 1")
+
+    @property
+    def input_points(self):
+        """
+        Return the input points of the available dataset splits, as a
+        dictionary keyed by split name.
+        """
+        to_return = {}
+        if hasattr(self, "train_dataset") and self.train_dataset is not None:
+            to_return["train"] = self.train_dataset.input_points
+        if hasattr(self, "val_dataset") and self.val_dataset is not None:
+            to_return["val"] = self.val_dataset.input_points
+        if hasattr(self, "test_dataset") and self.test_dataset is not None:
+            to_return["test"] = self.test_dataset.input_points
+        return to_return