
Commit 2866eca

dario-coscia, GiovanniCanali, and FilippoOlivo authored

Update solvers (#434)

* Enable DDP training with batch_size=None and add validity check for split sizes
* Refactoring SolverInterfaces (#435)
* Solver update + weighting
* Updating PINN for 0.2
* Modify GAROM + tests
* Adding more versatile loggers
* Disable compilation when running on Windows
* Fix tests

---------

Co-authored-by: Dario Coscia <[email protected]>
Co-authored-by: giovanni <[email protected]>
Co-authored-by: FilippoOlivo <[email protected]>
1 parent fb523d7 commit 2866eca


50 files changed: +2930 -4269 lines

pina/callbacks/processing_callbacks.py

Lines changed: 28 additions & 40 deletions
@@ -1,7 +1,5 @@
 """PINA Callbacks Implementations"""
 
-from lightning.pytorch.core.module import LightningModule
-from lightning.pytorch.trainer.trainer import Trainer
 import torch
 import copy
 
@@ -16,66 +14,64 @@ class MetricTracker(Callback):
 
     def __init__(self, metrics_to_track=None):
         """
-        PINA Implementation of a Lightning Callback for Metric Tracking.
+        Lightning Callback for Metric Tracking.
 
-        This class provides functionality to track relevant metrics during
-        the training process.
+        Tracks specific metrics during the training process.
 
-        :ivar _collection: A list to store collected metrics after each
-            training epoch.
+        :ivar _collection: A list to store collected metrics after each epoch.
 
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
-
-        :return: A dictionary containing aggregated metric values.
-        :rtype: dict
-
-        Example:
-            >>> tracker = MetricTracker()
-            >>> # ... Perform training ...
-            >>> metrics = tracker.metrics
+        :param metrics_to_track: List of metrics to track. Defaults to
+            train/val loss.
+        :type metrics_to_track: list, optional
         """
         super().__init__()
         self._collection = []
-        if metrics_to_track is not None:
-            metrics_to_track = ['train_loss_epoch', 'train_loss_step', 'val_loss']
-        self.metrics_to_track = metrics_to_track
+        # Default to tracking 'train_loss' and 'val_loss' if not specified
+        self.metrics_to_track = metrics_to_track or ['train_loss', 'val_loss']
 
     def on_train_epoch_end(self, trainer, pl_module):
         """
         Collect and track metrics at the end of each training epoch.
 
         :param trainer: The trainer object managing the training process.
         :type trainer: pytorch_lightning.Trainer
-        :param pl_module: Placeholder argument.
+        :param pl_module: The model being trained (not used here).
         """
-        super().on_train_epoch_end(trainer, pl_module)
+        # Track metrics after the first epoch onwards
        if trainer.current_epoch > 0:
-            self._collection.append(
-                copy.deepcopy(trainer.logged_metrics)
-            )  # track them
+            # Append only the tracked metrics to avoid unnecessary data
+            tracked_metrics = {
+                k: v for k, v in trainer.logged_metrics.items()
+                if k in self.metrics_to_track
+            }
+            self._collection.append(copy.deepcopy(tracked_metrics))
 
     @property
     def metrics(self):
         """
-        Aggregate collected metrics during training.
+        Aggregate collected metrics over all epochs.
 
         :return: A dictionary containing aggregated metric values.
         :rtype: dict
         """
-        common_keys = set.intersection(*map(set, self._collection))
-        v = {
+        if not self._collection:
+            return {}
+
+        # Get intersection of keys across all collected dictionaries
+        common_keys = set(self._collection[0]).intersection(*self._collection[1:])
+
+        # Stack the metric values for common keys and return
+        return {
             k: torch.stack([dic[k] for dic in self._collection])
-            for k in common_keys
+            for k in common_keys if k in self.metrics_to_track
         }
-        return v
 
 
 class PINAProgressBar(TQDMProgressBar):
 
     BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
 
-    def __init__(self, metrics="val_loss", **kwargs):
+    def __init__(self, metrics="val", **kwargs):
         """
         PINA Implementation of a Lightning Callback for enriching the progress
         bar.
@@ -131,14 +127,6 @@ def get_metrics(self, trainer, model):
         pbar_metrics = {
             key: pbar_metrics[key] for key in self._sorted_metrics
         }
-        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
-        if duplicates:
-            rank_zero_warn(
-                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
-                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value. "
-                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
-            )
-
         return {**standard_metrics, **pbar_metrics}
 
     def on_fit_start(self, trainer, pl_module):
@@ -154,7 +142,7 @@ def on_fit_start(self, trainer, pl_module):
         for key in self._sorted_metrics:
             if (
                 key not in trainer.solver.problem.conditions.keys()
-                and key != "mean"
+                and key != "train" and key != "val"
             ):
                 raise KeyError(f"Key '{key}' is not present in the dictionary")
         # add the loss pedix
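
For reference, a minimal usage sketch of the updated MetricTracker. The trainer construction is assumed context (a PINA/Lightning trainer built elsewhere), not something this diff defines:

from pina.callbacks.processing_callbacks import MetricTracker

# Track only the training loss; omitting the argument defaults to
# ['train_loss', 'val_loss'].
tracker = MetricTracker(metrics_to_track=['train_loss'])

# Hypothetical: register the callback on a trainer and fit.
# trainer = Trainer(solver=solver, callbacks=[tracker], max_epochs=50)
# trainer.train()

# After training, each tracked metric is stacked across epochs into a
# single tensor, e.g. {'train_loss': tensor([...])}.
metrics = tracker.metrics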

pina/collector.py

Lines changed: 1 addition & 7 deletions
@@ -1,5 +1,4 @@
-from . import LabelTensor
-from .utils import check_consistency, merge_tensors
+from .utils import check_consistency
 
 
 class Collector:
@@ -8,11 +7,6 @@ def __init__(self, problem):
         # creating a hook between collector and problem
         self.problem = problem
 
-        # this variable is used to store the data in the form:
-        # {'[condition_name]' :
-        #      {'input_points' : Tensor,
-        #       '[equation/output_points/conditional_variables]': Tensor}
-        # }
         # those variables are used for the dataloading
         self._data_collections = {name: {} for name in self.problem.conditions}
         self.conditions_name = {

pina/data/data_module.py

Lines changed: 96 additions & 47 deletions
@@ -1,6 +1,5 @@
 import logging
 from lightning.pytorch import LightningDataModule
-import math
 import torch
 from ..label_tensor import LabelTensor
 from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \
@@ -10,8 +9,38 @@
 from ..collector import Collector
 
 class DummyDataloader:
-    def __init__(self, dataset, device):
-        self.dataset = dataset.get_all_data()
+    """
+    Dummy dataloader used when the batch size is None. It collects all the
+    data in self.dataset and returns it as a single batch when called.
+    """
+
+    def __init__(self, dataset):
+        """
+        :param dataset: The dataset object to be processed.
+        :notes:
+            - **Distributed Environment**:
+                - Divides the dataset across processes using the
+                  rank and world size.
+                - Fetches only the portion of data corresponding to
+                  the current process.
+            - **Non-Distributed Environment**:
+                - Fetches the entire dataset.
+        """
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            rank = torch.distributed.get_rank()
+            world_size = torch.distributed.get_world_size()
+            if len(dataset) < world_size:
+                raise RuntimeError(
+                    "Dimension of the dataset smaller than world size."
+                    " Increase the size of the partition or use a single GPU")
+            idx, i = [], rank
+            while i < len(dataset):
+                idx.append(i)
+                i += world_size
+            self.dataset = dataset.fetch_from_idx_list(idx)
+        else:
+            self.dataset = dataset.get_all_data()
 
     def __iter__(self):
         return self
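
The rank-striding in __init__ assigns sample i to the process with i % world_size == rank; the while loop is equivalent to a strided range. A standalone sketch of that arithmetic (hypothetical helper, not part of the commit):

def strided_indices(n_samples, rank, world_size):
    # Indices owned by `rank` when `n_samples` items are strided across
    # `world_size` processes -- same arithmetic as the while loop above.
    if n_samples < world_size:
        raise RuntimeError("Dimension of the dataset smaller than world size.")
    return list(range(rank, n_samples, world_size))

# With 10 samples on 3 processes:
#   rank 0 -> [0, 3, 6, 9]
#   rank 1 -> [1, 4, 7]
#   rank 2 -> [2, 5, 8]
for rank in range(3):
    print(rank, strided_indices(10, rank, 3))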
@@ -50,7 +79,7 @@ def _collate_standard_dataloader(self, batch):
         for arg in condition_args:
             data_list = [batch[idx][condition_name][arg] for idx in range(
                 min(len(batch),
-                self.max_conditions_lengths[condition_name]))]
+                    self.max_conditions_lengths[condition_name]))]
             if isinstance(data_list[0], LabelTensor):
                 single_cond_dict[arg] = LabelTensor.stack(data_list)
             elif isinstance(data_list[0], torch.Tensor):
@@ -61,7 +90,6 @@ def _collate_standard_dataloader(self, batch):
             batch_dict[condition_name] = single_cond_dict
         return batch_dict
 
-
     def __call__(self, batch):
        return self.callable_function(batch)
 
@@ -99,6 +127,7 @@ def __init__(self,
                  ):
         """
         Initialize the object, creating dataset based on input problem
+        :param problem: Problem where data are defined
         :param train_size: number/percentage of elements in train split
         :param test_size: number/percentage of elements in test split
         :param val_size: number/percentage of elements in evaluation split
@@ -112,6 +141,9 @@ def __init__(self,
         self.shuffle = shuffle
         self.repeat = repeat
 
+        # Check if the splits are correct
+        self._check_slit_sizes(train_size, test_size, val_size, predict_size)
+
         # Begin Data splitting
         splits_dict = {}
         if train_size > 0:
@@ -179,23 +211,28 @@ def _split_condition(condition_dict, splits_dict):
         len_condition = len(condition_dict['input_points'])
 
         lengths = [
-            int(math.floor(len_condition * length)) for length in
+            int(len_condition * length) for length in
             splits_dict.values()
         ]
 
         remainder = len_condition - sum(lengths)
         for i in range(remainder):
             lengths[i % len(lengths)] += 1
-        splits_dict = {k: v for k, v in zip(splits_dict.keys(), lengths)
+
+        splits_dict = {k: max(1, v) for k, v in zip(splits_dict.keys(), lengths)
                        }
         to_return_dict = {}
         offset = 0
+
         for stage, stage_len in splits_dict.items():
             to_return_dict[stage] = {k: v[offset:offset + stage_len]
                                      for k, v in condition_dict.items() if
                                      k != 'equation'
                                      # Equations are NEVER dataloaded
                                      }
+            if offset + stage_len > len_condition:
+                offset = len_condition - 1
+                continue
             offset += stage_len
         return to_return_dict
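
The new length computation truncates each fractional split, hands the remainder back one element at a time, and clamps every split to at least one element. A standalone sketch of the same arithmetic (hypothetical helper):

def split_lengths(n, fractions):
    # Truncate, distribute the remainder round-robin, then guarantee
    # at least one element per split (the max(1, ...) from the diff).
    lengths = [int(n * f) for f in fractions]
    remainder = n - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return [max(1, length) for length in lengths]

print(split_lengths(10, [0.7, 0.2, 0.1]))  # [7, 2, 1]
print(split_lengths(7, [0.5, 0.5]))        # truncation gives [3, 3]; remainder -> [4, 3]

Note the clamp can push the lengths past the dataset size, which is why the loop above guards offset + stage_len > len_condition.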

@@ -234,6 +271,26 @@ def _apply_shuffle(condition_dict, len_data):
             dataset_dict[key].update({condition_name: data})
         return dataset_dict
 
+
+    def _create_dataloader(self, split, dataset):
+        shuffle = self.shuffle if split == 'train' else False
+        # Use custom batching (good if batch size is large)
+        if self.batch_size is not None:
+            sampler = PinaSampler(dataset, self.batch_size,
+                                  shuffle, self.automatic_batching)
+            if self.automatic_batching:
+                collate = Collator(self.find_max_conditions_lengths(split))
+            else:
+                collate = Collator(None, dataset)
+            return DataLoader(dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+        dataloader = DummyDataloader(dataset)
+        dataloader.dataset = self._transfer_batch_to_device(
+            dataloader.dataset, self.trainer.strategy.root_device, 0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader
+
     def find_max_conditions_lengths(self, split):
         max_conditions_lengths = {}
         for k, v in self.collector_splits[split].items():
@@ -250,60 +307,28 @@ def val_dataloader(self):
         """
         Create the validation dataloader
         """
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(self.val_dataset, self.batch_size,
-                                  self.shuffle, self.automatic_batching)
-            if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths('val'))
-            else:
-                collate = Collator(None, self.val_dataset)
-            return DataLoader(self.val_dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
-        dataloader = DummyDataloader(self.val_dataset,
-                                     self.trainer.strategy.root_device)
-        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
-                                                            self.trainer.strategy.root_device,
-                                                            0)
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return self._create_dataloader('val', self.val_dataset)
 
     def train_dataloader(self):
         """
         Create the training dataloader
         """
-        # Use custom batching (good if batch size is large)
-        if self.batch_size is not None:
-            sampler = PinaSampler(self.train_dataset, self.batch_size,
-                                  self.shuffle, self.automatic_batching)
-            if self.automatic_batching:
-                collate = Collator(self.find_max_conditions_lengths('train'))
-
-            else:
-                collate = Collator(None, self.train_dataset)
-            return DataLoader(self.train_dataset, self.batch_size,
-                              collate_fn=collate, sampler=sampler)
-        dataloader = DummyDataloader(self.train_dataset,
-                                     self.trainer.strategy.root_device)
-        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
-                                                            self.trainer.strategy.root_device,
-                                                            0)
-        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
-        return dataloader
+        return self._create_dataloader('train', self.train_dataset)
 
     def test_dataloader(self):
         """
         Create the testing dataloader
         """
-        raise NotImplementedError("Test dataloader not implemented")
+        return self._create_dataloader('test', self.test_dataset)
 
     def predict_dataloader(self):
         """
         Create the prediction dataloader
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
-    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
+    @staticmethod
+    def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
         return batch
 
     def _transfer_batch_to_device(self, batch, device, dataloader_idx):
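
With all three hooks delegating to _create_dataloader, the choice of loader is driven entirely by batch_size. A hedged construction sketch; the datamodule class name and keyword order are assumed from the parameters visible in this diff, not confirmed by it:

# Hypothetical usage: batch_size=None routes every stage through
# DummyDataloader (the whole per-rank dataset as one batch) ...
dm_full = PinaDataModule(problem, train_size=0.7, test_size=0.2,
                         val_size=0.1, batch_size=None)

# ... while an integer batch_size builds a torch DataLoader with
# PinaSampler and a Collator.
dm_mini = PinaDataModule(problem, train_size=0.7, test_size=0.2,
                         val_size=0.1, batch_size=32)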
@@ -312,10 +337,34 @@ def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         training loop and is used to transfer the batch to the device.
         """
         batch = [
-            (k, super(LightningDataModule, self).transfer_batch_to_device(v,
-                                                                          device,
-                                                                          dataloader_idx))
+            (k,
+             super(LightningDataModule, self).transfer_batch_to_device(
+                 v, device, dataloader_idx))
             for k, v in batch.items()
         ]
 
         return batch
+
+    @staticmethod
+    def _check_slit_sizes(train_size, test_size, val_size, predict_size):
+        """
+        Check if the splits are correct
+        """
+        if train_size < 0 or test_size < 0 or val_size < 0 or predict_size < 0:
+            raise ValueError("The splits must be positive")
+        if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6:
+            raise ValueError("The sum of the splits must be 1")
+
+    @property
+    def input_points(self):
+        """
+        Return the input points of each available split.
+        """
+        to_return = {}
+        if hasattr(self, "train_dataset") and self.train_dataset is not None:
+            to_return["train"] = self.train_dataset.input_points
+        if hasattr(self, "val_dataset") and self.val_dataset is not None:
+            to_return["val"] = self.val_dataset.input_points
+        if hasattr(self, "test_dataset") and self.test_dataset is not None:
+            to_return["test"] = self.test_dataset.input_points
+        return to_return
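
The 1e-6 tolerance in _check_slit_sizes is deliberate: typical split fractions do not sum to exactly 1.0 in floating point, so an exact equality test would reject valid configurations. For instance:

sizes = (0.7, 0.2, 0.1, 0.0)
print(sum(sizes) == 1.0)           # False: the sum is 0.9999999999999999
print(abs(sum(sizes) - 1) < 1e-6)  # True: accepted by the check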
