How to save additional variables while checkpointing? #13080
Answered by rohitgr7
arunpatro-meesho asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
In vanilla PyTorch, I save like this:

    if v_loss < best_val_loss:
        print("Found better model, saving")
        model_save_path = f"models/{sscat_id}/{attribute}/ts={time_stamp}/best.pth"
        best_val_loss = v_loss
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "report": report,
                "label_encoder_dict": label_encoder_dict,
                "inverse_label_encoder_dict": {v: k for k, v in label_encoder_dict.items()},
                "weighted_f1": weighted_f1,
            },
            model_save_path,
        )

How can I save extra keys like these in a Lightning checkpoint?
Answered by rohitgr7 on May 16, 2022
You can use the on_save_checkpoint hook inside your LightningModule.
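For example, here is a minimal sketch. The class name MyModel is hypothetical, and it assumes the report, label_encoder_dict, and weighted_f1 values from the question are kept as attributes on the module; adapt the attribute names to however you actually track them.

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def __init__(self, label_encoder_dict):
            super().__init__()
            # assumed: these attributes mirror the variables from the vanilla-PyTorch snippet
            self.label_encoder_dict = label_encoder_dict
            self.report = None
            self.weighted_f1 = None

        def on_save_checkpoint(self, checkpoint):
            # `checkpoint` already holds the epoch, model state_dict, optimizer states, etc.;
            # any extra keys added here are written into the same .ckpt file
            checkpoint["report"] = self.report
            checkpoint["label_encoder_dict"] = self.label_encoder_dict
            checkpoint["inverse_label_encoder_dict"] = {
                v: k for k, v in self.label_encoder_dict.items()
            }
            checkpoint["weighted_f1"] = self.weighted_f1

        def on_load_checkpoint(self, checkpoint):
            # restore the extra keys when the checkpoint is loaded back into the module
            self.report = checkpoint.get("report")
            self.label_encoder_dict = checkpoint.get("label_encoder_dict")
            self.weighted_f1 = checkpoint.get("weighted_f1")

Because Lightning checkpoints are ordinary dicts written with torch.save, the extra keys are also readable outside Lightning, e.g. torch.load(ckpt_path)["label_encoder_dict"], and they are restored inside Lightning via the on_load_checkpoint hook shown above.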
Answer selected by arunpatro-meesho