
Activation config for callbacks #20494

@JohnHerry


Description & Motivation

My situation: I am using a ModelCheckpoint callback, and my configuration is as follows:

callbacks:
  - class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      monitor: val_loss
      filename: "checkpoint_{epoch}_{step}_{val_loss}.ckpt"
      save_top_k: 55
      every_n_train_steps: 1000

My training has a warmup stage of more than 1000 steps. Catastrophically, the lowest value of the monitored metric occurs during that warmup stage, when the model is far from stable, so the "best" checkpoints are all saved there. I want to configure the ModelCheckpoint callback to activate only after the warmup stage, i.e. after a certain number of training steps has passed. For this example, I can do the following:

from pytorch_lightning.callbacks import ModelCheckpoint


class MyModelCheckpoint(ModelCheckpoint):
    def __init__(self, active_after_step=20000, **kwargs):
        super().__init__(**kwargs)
        self.active_after_step = active_after_step

    def _should_skip_saving_checkpoint(self, trainer):
        # Skip all saving until the warmup threshold has been passed.
        if trainer.global_step < self.active_after_step:
            return True
        return super()._should_skip_saving_checkpoint(trainer)
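
For completeness, the subclass could then be referenced from the same CLI config as any other callback (my_project.callbacks is a hypothetical module path):

callbacks:
  - class_path: my_project.callbacks.MyModelCheckpoint
    init_args:
      monitor: val_loss
      filename: "checkpoint_{epoch}_{step}_{val_loss}.ckpt"
      save_top_k: 55
      every_n_train_steps: 1000
      active_after_step: 20000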

But this is not a good design, and it cannot solve similar problems for other callbacks. So, is there a root design that supports an activation condition in the configuration of every callback?
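
One shape such a root design could take, sketched below, is a generic wrapper callback that forwards hooks to an inner callback only once a step threshold is reached. This is a minimal sketch; ActiveAfterStep and active_after_step are hypothetical names, not an existing Lightning API, and a complete version would forward every hook plus the state-dict methods.

from pytorch_lightning import Callback


class ActiveAfterStep(Callback):
    """Hypothetical wrapper that gates an inner callback on global_step."""

    def __init__(self, inner: Callback, active_after_step: int = 0):
        self.inner = inner
        self.active_after_step = active_after_step

    def _active(self, trainer) -> bool:
        return trainer.global_step >= self.active_after_step

    def setup(self, trainer, pl_module, stage):
        # Setup must always run (e.g. ModelCheckpoint resolves its dirpath here).
        self.inner.setup(trainer, pl_module, stage)

    # ModelCheckpoint saves from these two hooks; other callbacks would need
    # their own hooks forwarded the same way.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._active(trainer):
            self.inner.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

    def on_validation_end(self, trainer, pl_module):
        if self._active(trainer):
            self.inner.on_validation_end(trainer, pl_module)

Such a wrapper would let any callback be gated by an activation condition from the config, instead of subclassing each one.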

If such a feature already exists and I did not notice it, I apologize for my carelessness.

Pitch

No response

Alternatives

No response

Additional context

No response

cc @Borda

Labels: feature (Is an improvement or enhancement)