Description & Motivation
My situation: I am using a ModelCheckpoint callback, configured as follows:
```yaml
callbacks:
  - class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      monitor: val_loss
      filename: "checkpoint_{epoch}_{step}_{val_loss}"  # ModelCheckpoint appends ".ckpt" itself
      save_top_k: 55
      every_n_train_steps: 1000
```
My training has a warmup stage that lasts more than 1000 steps. Unfortunately, the lowest value of the monitored metric occurs during warmup, when the model is still far from stable, so that checkpoint ends up ranked as the best one. I would therefore like to configure the ModelCheckpoint callback to activate only after the warmup stage, or more generally only after a certain number of training steps have passed. For this example, I can do the following:
```python
from pytorch_lightning.callbacks import ModelCheckpoint


class MyModelCheckpoint(ModelCheckpoint):
    def __init__(self, *args, active_after_step=20000, **kwargs):
        # Pass dirpath, filename, monitor, ... through to ModelCheckpoint unchanged
        super().__init__(*args, **kwargs)
        self.active_after_step = active_after_step

    def _should_skip_saving_checkpoint(self, trainer):
        # Overrides a private ModelCheckpoint method: skip saving until the step threshold
        if trainer.global_step < self.active_after_step:
            return True
        return super()._should_skip_saving_checkpoint(trainer)
```
However, this is not a good design, and it does not solve the same problem for other callbacks. Is there a built-in mechanism for configuring an activation condition on each callback?
If such a mechanism already exists and I missed it, I apologize for the oversight.
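One possible shape for the generic mechanism asked about above is a wrapper that takes any callback plus an activation condition. The following is a minimal plain-Python sketch, not a Lightning API: `GatedCallback`, `after_step`, and `RecordingCallback` are hypothetical names, the trainer is a `SimpleNamespace` stand-in that only exposes `global_step`, and only two hooks are forwarded.

```python
from types import SimpleNamespace
from typing import Any, Callable


class GatedCallback:
    """Hypothetical wrapper (not part of Lightning) that forwards selected
    hooks to `inner` only while `condition(trainer)` is true."""

    def __init__(self, inner: Any, condition: Callable[[Any], bool]):
        self.inner = inner
        self.condition = condition

    def _forward(self, hook: str, trainer: Any, *args: Any) -> Any:
        # Skip the wrapped callback entirely until the condition holds.
        if not self.condition(trainer):
            return None
        return getattr(self.inner, hook)(trainer, *args)

    # Explicit forwarding for the hooks this sketch cares about.
    def on_train_batch_end(self, trainer: Any, *args: Any) -> Any:
        return self._forward("on_train_batch_end", trainer, *args)

    def on_validation_end(self, trainer: Any, *args: Any) -> Any:
        return self._forward("on_validation_end", trainer, *args)


def after_step(n: int) -> Callable[[Any], bool]:
    """Condition factory: active once trainer.global_step >= n."""
    return lambda trainer: trainer.global_step >= n


class RecordingCallback:
    """Stand-in callback that records the steps at which its hooks ran."""

    def __init__(self) -> None:
        self.seen: list = []

    def on_train_batch_end(self, trainer: Any, *args: Any) -> None:
        self.seen.append(trainer.global_step)

    def on_validation_end(self, trainer: Any, *args: Any) -> None:
        self.seen.append(("val", trainer.global_step))


inner = RecordingCallback()
gated = GatedCallback(inner, after_step(2000))
for step in (1000, 2000, 3000):
    gated.on_train_batch_end(SimpleNamespace(global_step=step))
print(inner.seen)  # [2000, 3000]
```

Using something like this with a real Trainer would presumably require the wrapper to subclass Lightning's `Callback` and forward every hook (and any checkpoint state) the wrapped callback relies on; the sketch leaves that out. Because the condition is a plain callable, the same wrapper could gate on an epoch count or another trainer attribute instead of the step.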
Pitch
No response
Alternatives
No response
Additional context
No response
cc @Borda