Commit 5c84ec7

add custom fit loop for custom hook handling
1 parent 85656da commit 5c84ec7

File tree: 1 file changed, +38 -0 lines changed

chebai/trainer/CustomTrainer.py

Lines changed: 38 additions & 0 deletions
@@ -4,8 +4,11 @@
 import pandas as pd
 import torch
 from lightning import LightningModule, Trainer
+from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loops.fit_loop import _FitLoop
+from lightning.pytorch.trainer import call
 from torch.nn.utils.rnn import pad_sequence

 from chebai.loggers.custom import CustomLogger
@@ -39,6 +42,9 @@ def __init__(self, *args, **kwargs):
                 log_kwargs[log_key] = log_value
             self.logger.log_hyperparams(log_kwargs)

+        # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops)
+        self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs)
+
     def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]:
         """
         Resolves logging arguments, handling nested structures such as lists and complex objects.
@@ -147,3 +153,35 @@ def log_dir(self) -> Optional[str]:

         dirpath = self.strategy.broadcast(dirpath)
         return dirpath
+
+
+class LoadDataLaterFitLoop(_FitLoop):
+
+    def on_advance_start(self) -> None:
+        """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are set up. This is necessary
+        so that the dataloaders can get information from the model. For example, the ``on_train_epoch_start``
+        hook sets the ``curr_epoch`` attribute of the PubChemBatched dataset. With Lightning's default fit
+        loop, the dataloaders would always load batch 0 first, run an epoch, then get the epoch number
+        (usually 0, unless resuming from a checkpoint), and then load batch 0 again (or some other batch).
+        With this implementation, the dataloaders are set up after the epoch number is set, so that the
+        correct batch is loaded."""
+        trainer = self.trainer
+
+        # update the epoch value for all samplers
+        assert self._combined_loader is not None
+        for i, dl in enumerate(self._combined_loader.flattened):
+            _set_sampler_epoch(dl, self.epoch_progress.current.processed)
+
+        self.restarted  # no effect: the looked-up value is not used
+        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
+            if not self.restarted_on_epoch_start:
+                self.epoch_progress.increment_ready()
+
+            call._call_callback_hooks(trainer, "on_train_epoch_start")
+            call._call_lightning_module_hook(trainer, "on_train_epoch_start")
+
+        self.epoch_progress.increment_started()
+
+        # this is usually at the front of on_advance_start, but here we need it at the end
+        # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs`
+        self.setup_data()
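
Why the reordering matters in practice: the sketch below shows the pattern the docstring describes, where a dataset serves a different shard of the corpus each epoch and relies on the model's on_train_epoch_start hook to tell it which epoch is running. The names (ShardedPretrainingDataset, ShardAwareModule) are hypothetical stand-ins, not the chebai implementation; in chebai the hook setting curr_epoch might equally live on a callback.

import torch
from torch.utils.data import DataLoader, Dataset
from lightning import LightningModule


class ShardedPretrainingDataset(Dataset):
    """Hypothetical PubChemBatched-style dataset: each epoch reads a different shard."""

    def __init__(self, shards):
        self.shards = shards        # e.g. a list of per-epoch sample lists
        self.curr_epoch = 0         # updated externally before each epoch

    def _active_shard(self):
        return self.shards[self.curr_epoch % len(self.shards)]

    def __len__(self):
        return len(self._active_shard())

    def __getitem__(self, idx):
        return self._active_shard()[idx]


class ShardAwareModule(LightningModule):
    """Hypothetical model whose epoch-start hook tells the dataset which shard to serve."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.layer = torch.nn.Linear(1, 1)

    def on_train_epoch_start(self):
        # LoadDataLaterFitLoop calls this hook before setup_data(), so the
        # dataloader built afterwards already sees the correct epoch number.
        self.dataset.curr_epoch = self.current_epoch

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch.float().unsqueeze(-1)).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

As the relocated comment notes, setup_data() only rebuilds the dataloaders when trainer.reload_dataloaders_every_n_epochs asks for it, so a run relying on per-epoch shards would typically pass reload_dataloaders_every_n_epochs=1 to the trainer.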
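
The override mechanism itself follows the pattern from the linked Lightning docs: build a _FitLoop subclass with the trainer and its epoch bounds, then assign it to the trainer's fit_loop attribute, which is exactly what CustomTrainer.__init__ does above. A self-contained sketch of that pattern (VerboseFitLoop is hypothetical, and assigning the loop after construction is assumed to behave like the in-constructor assignment above):

from lightning import Trainer
from lightning.pytorch.loops.fit_loop import _FitLoop


class VerboseFitLoop(_FitLoop):
    """Hypothetical loop: adds a log line, then defers to the default behaviour."""

    def on_advance_start(self) -> None:
        print(f"starting epoch {self.epoch_progress.current.processed}")
        super().on_advance_start()


trainer = Trainer(max_epochs=3)
# same constructor arguments CustomTrainer passes: the trainer and its epoch bounds
trainer.fit_loop = VerboseFitLoop(trainer, trainer.min_epochs, trainer.max_epochs)

A subsequent trainer.fit(model) would then run with the replaced loop instead of the stock _FitLoop.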
