Save checkpoint with specific monitor criteria #11129
-
Hello everyone, I'm currently implementing a Wasserstein-type GAN with gradient penalty. I want to save checkpoints by monitoring the negative critic loss, which starts at low values, increases during the first epochs, and then decreases until it almost reaches 0. A plot of this loss can be seen in the paper: https://arxiv.org/pdf/1704.00028.pdf

The problem is that if I use `ModelCheckpoint` with the monitor parameter set to the negative critic loss and `mode='min'`, it basically saves only the first epoch. I don't want to consider the first training epochs, when the negative loss increases, but only the later epochs, when the loss decreases. I'm currently using multi-GPU training.

How can I implement this? Should I override `on_train_epoch_end` and save the checkpoints there, after checking the criteria above? Or should I use a Lightning Callback? If so, how can I access the monitored values?
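For concreteness, a minimal sketch of the setup described above (the metric name `critic_loss` and the logging call are assumptions, since the actual module code is not shown):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# The LightningModule is assumed to log the negative critic loss, e.g. in
# training_step: self.log("critic_loss", neg_critic_loss, sync_dist=True)
checkpoint_callback = ModelCheckpoint(
    monitor="critic_loss",  # value rises during the first epochs, then decreases
    mode="min",             # so mode="min" ends up keeping only the very first epochs
    save_top_k=1,
)
trainer = Trainer(gpus=2, callbacks=[checkpoint_callback])  # multi-GPU training
```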
Replies: 1 comment
-
Thanks to @tchaton on the Slack community, I solved the issue by overriding the `ModelCheckpoint` class. In `on_train_epoch_end` I added a new check that follows the conditions above, as such:

```python
# Imports assumed for this snippet (written against an older PyTorch Lightning 1.x
# API that still accepts `period` and `every_n_val_epochs`).
from datetime import timedelta
from pathlib import Path
from typing import Any, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class WGANModelCheckpoint(ModelCheckpoint):
    def __init__(self,
                 dirpath: Optional[Union[str, Path]] = None,
                 filename: Optional[str] = None,
                 monitor: Optional[str] = None,
                 verbose: bool = False,
                 save_last: Optional[bool] = None,
                 save_top_k: int = 1,
                 save_weights_only: bool = False,
                 mode: str = "min",
                 auto_insert_metric_name: bool = True,
                 every_n_train_steps: Optional[int] = None,
                 train_time_interval: Optional[timedelta] = None,
                 every_n_epochs: Optional[int] = None,
                 save_on_train_epoch_end: Optional[bool] = None,
                 period: Optional[int] = None,
                 every_n_val_epochs: Optional[int] = None):
        super().__init__(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_weights_only=save_weights_only,
            mode=mode,
            auto_insert_metric_name=auto_insert_metric_name,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            every_n_epochs=every_n_epochs,
            save_on_train_epoch_end=save_on_train_epoch_end,
            period=period,
            every_n_val_epochs=every_n_val_epochs)
        self.is_monitoring_on = False
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                           unused: Optional[Any] = None) -> None:
        """Save a checkpoint at the end of the training epoch."""
        # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
        trainer.fit_loop.global_step -= 1
        if (
            not self._should_skip_saving_checkpoint(trainer)
            and self._save_on_train_epoch_end
            and self._every_n_epochs > 0
            and (trainer.current_epoch + 1) % self._every_n_epochs == 0
            # Extra condition: only save once the monitored loss has started increasing.
            and (self.is_monitoring_on or self.monitor_can_start(trainer, pl_module))
        ):
            self.save_checkpoint(trainer)
        trainer.fit_loop.global_step += 1
    def monitor_can_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> bool:
        """Start monitoring only after the loss curve has started increasing."""
        monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
        current = monitor_candidates.get(self.monitor)
        # Check if the critic loss is increasing (the network is starting to train).
        if trainer.current_epoch > 0 and pl_module.previous_metric < current:
            self.is_monitoring_on = True
        pl_module.previous_metric = current.detach().clone()
        return self.is_monitoring_on
```
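A usage sketch, assuming the LightningModule logs the monitored value under the name `critic_loss` and defines a `previous_metric` attribute (both names are assumptions; `monitor_can_start` reads and rewrites `previous_metric` at every checkpoint check, so initializing it in the module guards against an `AttributeError` if the first comparison happens after epoch 0, e.g. with `every_n_epochs > 1`):

```python
# Usage sketch (assumed metric name "critic_loss"); the LightningModule is
# assumed to contain something along these lines:
#   self.previous_metric = torch.tensor(float("inf"))          # in __init__
#   self.log("critic_loss", neg_critic_loss, sync_dist=True)   # in training_step
checkpoint_callback = WGANModelCheckpoint(
    monitor="critic_loss",
    mode="min",
    save_top_k=3,
    save_on_train_epoch_end=True,  # the extra check above runs in on_train_epoch_end
)
trainer = pl.Trainer(gpus=2, callbacks=[checkpoint_callback])
```

Once `is_monitoring_on` flips to True, the callback behaves like a regular `ModelCheckpoint` with `mode="min"`, so only the post-peak decrease of the negative critic loss is tracked.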