 import pandas as pd
 import torch
 from lightning import LightningModule, Trainer
+from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loops.fit_loop import _FitLoop
+from lightning.pytorch.trainer import call
 from torch.nn.utils.rnn import pad_sequence
 
 from chebai.loggers.custom import CustomLogger
@@ -39,6 +42,9 @@ def __init__(self, *args, **kwargs):
                 log_kwargs[log_key] = log_value
             self.logger.log_hyperparams(log_kwargs)
 
+        # use a custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops)
+        self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs)
+
     def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
         """
         Resolves logging arguments, handling nested structures such as lists and complex objects.
@@ -147,3 +153,35 @@ def log_dir(self) -> Optional[str]:
 
         dirpath = self.strategy.broadcast(dirpath)
         return dirpath
+
+
+class LoadDataLaterFitLoop(_FitLoop):
+
+    def on_advance_start(self) -> None:
| 161 | + """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary |
| 162 | + so that the dataloaders can get information from the model. For example: The on_train_epoch_start |
| 163 | + hook sets the curr_epoch attribute of the PubChemBatched dataset. With the Lightning configuration, |
| 164 | + the dataloaders would always load batch 0 first, run an epoch, then get the epoch number (usually 0, |
| 165 | + unless resuming from a checkpoint), then load batch 0 again (or some other batch). With this |
| 166 | + implementation, the dataloaders are setup after the epoch number is set, so that the correct |
| 167 | + batch is loaded.""" |
+        trainer = self.trainer
+
+        # update the epoch value for all samplers
+        assert self._combined_loader is not None
+        for i, dl in enumerate(self._combined_loader.flattened):
+            _set_sampler_epoch(dl, self.epoch_progress.current.processed)
+
+        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
+            if not self.restarted_on_epoch_start:
+                self.epoch_progress.increment_ready()
+
+            call._call_callback_hooks(trainer, "on_train_epoch_start")
+            call._call_lightning_module_hook(trainer, "on_train_epoch_start")
+
+            self.epoch_progress.increment_started()
+
+        # this is usually at the start of on_advance_start, but here we need it at the end:
+        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
+        self.setup_data()
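
As an illustration of the behaviour the docstring describes, here is a minimal sketch of a dataset whose contents depend on an epoch attribute that must be set before the dataloader is created. EpochAwareDataset, n_chunks, and chunk_size are hypothetical names used only for this sketch, not chebai APIs; only the curr_epoch attribute mirrors the one mentioned in the docstring.

import torch
from torch.utils.data import DataLoader, Dataset


class EpochAwareDataset(Dataset):
    """Hypothetical dataset that serves a different chunk of data each epoch."""

    def __init__(self, n_chunks: int, chunk_size: int):
        self.n_chunks = n_chunks
        self.chunk_size = chunk_size
        # set by an on_train_epoch_start hook *before* the dataloader is built
        self.curr_epoch = 0

    def __len__(self) -> int:
        return self.chunk_size

    def __getitem__(self, idx: int) -> torch.Tensor:
        # which chunk is visible depends on the epoch value that was set
        # before the fit loop called setup_data()
        chunk = self.curr_epoch % self.n_chunks
        return torch.tensor([chunk, idx])


dataset = EpochAwareDataset(n_chunks=3, chunk_size=4)
dataset.curr_epoch = 2  # what the hook would do when resuming at epoch 2
loader = DataLoader(dataset, batch_size=2)
print(next(iter(loader)))  # tensor([[2, 0], [2, 1]]) -- chunk 2, not chunk 0

With the default fit loop, setup_data() runs before the on_train_epoch_start hooks, so the dataloader (and any worker processes) would be created while curr_epoch still holds its initial value; deferring setup_data() to the end of on_advance_start ensures the correct chunk is selected, including when resuming from a checkpoint.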