1 | 1 | import logging |
| 2 | +import warnings |
2 | 3 | from lightning.pytorch import LightningDataModule |
3 | 4 | import torch |
4 | 5 | from ..label_tensor import LabelTensor |
|
8 | 9 | from .dataset import PinaDatasetFactory |
9 | 10 | from ..collector import Collector |
10 | 11 |
|
| 12 | + |
11 | 13 | class DummyDataloader: |
12 | 14 | """ |
13 | 15 | Dummy dataloader used when batch size is None. It collects all the data |
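The docstring above pins down the full-batch path: when `batch_size` is `None`, every optimizer step consumes the entire dataset, so the loader can materialize the data once and hand back the same batch forever. A minimal standalone sketch of that pattern (independent of `DummyDataloader`'s actual internals, which this hunk truncates):

```python
class FullBatchLoader:
    """Serve the whole dataset as a single, pre-materialized batch."""

    def __init__(self, dataset):
        # Fetch everything up front; there is no per-iteration
        # collation cost because the same object is reused each step.
        self.dataset = [dataset[i] for i in range(len(dataset))]

    def __len__(self):
        return 1  # one (full) batch per epoch

    def __iter__(self):
        yield self.dataset
```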
@@ -57,7 +59,7 @@ def __init__(self, max_conditions_lengths, dataset=None): |
57 | 59 | self.max_conditions_lengths = max_conditions_lengths |
58 | 60 | self.callable_function = self._collate_custom_dataloader if \ |
59 | 61 | max_conditions_lengths is None else ( |
60 | | - self._collate_standard_dataloader) |
| 62 | + self._collate_standard_dataloader) |
61 | 63 | self.dataset = dataset |
62 | 64 |
|
63 | 65 | def _collate_custom_dataloader(self, batch): |
@@ -95,7 +97,7 @@ def __call__(self, batch): |
95 | 97 |
|
96 | 98 |
|
97 | 99 | class PinaSampler: |
98 | | - def __new__(self, dataset, batch_size, shuffle, automatic_batching): |
| 100 | + def __new__(cls, dataset, shuffle): |
99 | 101 |
|
100 | 102 | if (torch.distributed.is_available() and |
101 | 103 | torch.distributed.is_initialized()): |
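The new `__new__(cls, dataset, shuffle)` signature makes the factory role of `PinaSampler` explicit: it never instantiates itself, but returns whichever sampler suits the runtime, sharding indices across ranks only when a distributed process group is live. The remainder of the class body falls outside this hunk; assuming it follows the standard PyTorch pattern, the selection logic amounts to:

```python
import torch
from torch.utils.data import (DistributedSampler, RandomSampler,
                              SequentialSampler)


class SamplerFactory:
    """Sketch of the sampler selection assumed to happen in
    PinaSampler.__new__ after the distributed check above."""

    def __new__(cls, dataset, shuffle):
        if (torch.distributed.is_available()
                and torch.distributed.is_initialized()):
            # Each rank draws a disjoint shard of the indices.
            return DistributedSampler(dataset, shuffle=shuffle)
        # Single-process fallback.
        return (RandomSampler(dataset) if shuffle
                else SequentialSampler(dataset))
```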
@@ -123,15 +125,35 @@ def __init__(self, |
123 | 125 | batch_size=None, |
124 | 126 | shuffle=True, |
125 | 127 | repeat=False, |
126 | | - automatic_batching=False |
| 128 | + automatic_batching=False, |
| 129 | + num_workers=0, |
| 130 | + pin_memory=False, |
127 | 131 | ): |
128 | 132 | """ |
129 | | - Initialize the object, creating dataset based on input problem |
130 | | - :param problem: Problem where data are defined |
131 | | - :param train_size: number/percentage of elements in train split |
132 | | - :param test_size: number/percentage of elements in test split |
133 | | - :param val_size: number/percentage of elements in evaluation split |
134 | | - :param batch_size: batch size used for training |
| 133 | + Initialize the object, creating datasets based on the input problem. |
| 134 | +
|
| 135 | + :param problem: The problem defining the dataset. |
| 136 | + :type problem: AbstractProblem |
| 137 | + :param train_size: Fraction or number of elements in the training split. |
| 138 | + :type train_size: float |
| 139 | + :param test_size: Fraction or number of elements in the test split. |
| 140 | + :type test_size: float |
| 141 | + :param val_size: Fraction or number of elements in the validation split. |
| 142 | + :type val_size: float |
| 143 | + :param predict_size: Fraction or number of elements in the prediction split. |
| 144 | + :type predict_size: float |
| 145 | + :param batch_size: Batch size used for training. If None, the entire dataset is used per batch. |
| 146 | + :type batch_size: int or None |
| 147 | + :param shuffle: Whether to shuffle the dataset before splitting. |
| 148 | + :type shuffle: bool |
| 149 | + :param repeat: Whether to repeat the dataset indefinitely. |
| 150 | + :type repeat: bool |
| 151 | + :param automatic_batching: Whether to enable automatic batching. |
| 152 | + :type automatic_batching: bool |
| 153 | + :param num_workers: Number of worker processes for data loading. Default is 0 (serial loading). |
| 154 | + :type num_workers: int |
| 155 | + :param pin_memory: Whether to use pinned memory for faster data transfer to the GPU. Default is False. |
| 156 | + :type pin_memory: bool |
135 | 157 | """ |
136 | 158 | logging.debug('Start initialization of Pina DataModule') |
137 | 159 | logging.info('Start initialization of Pina DataModule') |
@@ -170,6 +192,15 @@ def __init__(self, |
170 | 192 | collector = Collector(problem) |
171 | 193 | collector.store_fixed_data() |
172 | 194 | collector.store_sample_domains() |
| 195 | + if batch_size is None and num_workers != 0: |
| 196 | + warnings.warn( |
| 197 | + "Setting num_workers when batch_size is None has no effect on " |
| 198 | + "the data loading process.") |
| 199 | + if batch_size is None and pin_memory: |
| 200 | + warnings.warn("Setting pin_memory to True has no effect when " |
| 201 | + "batch_size is None.") |
| 202 | + self.num_workers = num_workers |
| 203 | + self.pin_memory = pin_memory |
173 | 204 | self.collector_splits = self._create_splits(collector, splits_dict) |
174 | 205 | self.transfer_batch_to_device = self._transfer_batch_to_device |
175 | 206 |
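The two guards added to `__init__` above only warn; both arguments are still stored for the batched path. Pulled out as standalone code (the `validate_loader_args` name is illustrative, not part of the module), the behaviour is easy to exercise:

```python
import warnings


def validate_loader_args(batch_size, num_workers, pin_memory):
    """Mirror the no-op guards added to the DataModule __init__."""
    if batch_size is None and num_workers != 0:
        warnings.warn(
            "Setting num_workers when batch_size is None has no effect on "
            "the data loading process.")
    if batch_size is None and pin_memory:
        warnings.warn("Setting pin_memory to True has no effect when "
                      "batch_size is None.")


# Full-batch mode ignores both options downstream, so each check below
# emits a warning.
validate_loader_args(batch_size=None, num_workers=4, pin_memory=True)
```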
|
@@ -271,20 +302,27 @@ def _apply_shuffle(condition_dict, len_data): |
271 | 302 | dataset_dict[key].update({condition_name: data}) |
272 | 303 | return dataset_dict |
273 | 304 |
|
274 | | - |
275 | 305 | def _create_dataloader(self, split, dataset): |
276 | 306 | shuffle = self.shuffle if split == 'train' else False |
| 307 | + # Suppress the warning about num_workers. |
| 308 | + # In many cases, especially for PINNs, serial data loading can outperform parallel data loading. |
| 309 | + warnings.filterwarnings( |
| 310 | + "ignore", |
| 311 | + message=( |
| 312 | + r"The '(train|val|test)_dataloader' does not have many workers which may be a bottleneck."), |
| 313 | + module="lightning.pytorch.trainer.connectors.data_connector" |
| 314 | + ) |
277 | 315 | # Use custom batching (good if batch size is large) |
278 | 316 | if self.batch_size is not None: |
279 | | - sampler = PinaSampler(dataset, self.batch_size, |
280 | | - shuffle, self.automatic_batching) |
| 317 | + sampler = PinaSampler(dataset, shuffle) |
281 | 318 | if self.automatic_batching: |
282 | 319 | collate = Collator(self.find_max_conditions_lengths(split)) |
283 | 320 |
|
284 | 321 | else: |
285 | 322 | collate = Collator(None, dataset) |
286 | 323 | return DataLoader(dataset, self.batch_size, |
287 | | - collate_fn=collate, sampler=sampler) |
| 324 | + collate_fn=collate, sampler=sampler, |
| 325 | + num_workers=self.num_workers, pin_memory=self.pin_memory) |
288 | 326 | dataloader = DummyDataloader(dataset) |
289 | 327 | dataloader.dataset = self._transfer_batch_to_device( |
290 | 328 | dataloader.dataset, self.trainer.strategy.root_device, 0) |
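A note on the suppression in `_create_dataloader`: `warnings.filterwarnings` treats `message` as a regular expression matched (case-insensitively) against the start of the warning text, and `module` likewise against the emitting module's name, which is why the `(train|val|test)` alternation covers all three Lightning dataloader warnings. A self-contained check of the mechanism:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "message" is a regex anchored at the start of the warning text.
    warnings.filterwarnings(
        "ignore",
        message=r"The '(train|val|test)_dataloader' does not have many workers")
    warnings.warn("The 'val_dataloader' does not have many workers "
                  "which may be a bottleneck.")
    warnings.warn("An unrelated warning.")

# Only the unrelated warning survives the filter.
assert len(caught) == 1
```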
|