Saving checkpoint, hparams & tfevents after training to separate folder #11779
Answered by rohitgr7
dispoth asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Thanks to all the contributors of PyTorch Lightning for a fantastic product! I want to save a checkpoint, hparams & tfevents after training finishes. I have written this callback:

```python
import pytorch_lightning as pl
from pytorch_lightning.core.saving import save_hparams_to_yaml


class AfterTrainCheckpoint(pl.Callback):
    """Callback for saving the checkpoint weights, hparams and tf.events after training finishes."""

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        print("Saving final checkpoint...")
        # As we advance one step at the end of training, we use `global_step - 1`
        final_checkpoint_name = f"final_models/final_step_{trainer.global_step - 1}.ckpt"
        final_hparams_name = f"final_models/final_step_{trainer.global_step - 1}.yaml"
        trainer.save_checkpoint(final_checkpoint_name)
        save_hparams_to_yaml(config_yaml=final_hparams_name, hparams=trainer.model.hparams)
```
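For reference, attaching the callback would look like the following minimal sketch; `MyModel` and the trainer arguments are placeholders, not part of the original post:

```python
# Hypothetical usage: MyModel stands in for the user's LightningModule.
model = MyModel()
trainer = pl.Trainer(max_epochs=10, callbacks=[AfterTrainCheckpoint()])
trainer.fit(model)
```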
Answered by rohitgr7 · Feb 7, 2022
hey @dispoth !!

Use `on_fit_end` instead: the last checkpoint from `ModelCheckpoint` is saved in its `on_train_end` hook, so it's not guaranteed that the ckpt exists yet when your callback's `on_train_end` is called. You can also use `trainer.log_dir` to build the save path, so the files land next to the tfevents. See the docs for `on_train_end` and `on_fit_end`.
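Putting the answer together, here is a minimal sketch of the callback moved to `on_fit_end` and rooted at `trainer.log_dir`; the `final_models` subfolder and filenames simply mirror the question, and the hook signature assumes Lightning 1.x:

```python
import os

import pytorch_lightning as pl
from pytorch_lightning.core.saving import save_hparams_to_yaml


class AfterTrainCheckpoint(pl.Callback):
    """Save the final checkpoint and hparams once fitting has fully finished."""

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # `on_fit_end` runs after ModelCheckpoint's `on_train_end`, so the last
        # checkpoint written by ModelCheckpoint already exists at this point.
        step = trainer.global_step - 1
        out_dir = os.path.join(trainer.log_dir, "final_models")  # alongside the tfevents files
        os.makedirs(out_dir, exist_ok=True)
        trainer.save_checkpoint(os.path.join(out_dir, f"final_step_{step}.ckpt"))
        save_hparams_to_yaml(
            config_yaml=os.path.join(out_dir, f"final_step_{step}.yaml"),
            hparams=pl_module.hparams,
        )
```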