using EMA with model checkpoints #11276
-
I'm trying to incorporate the pytorch_ema library into the PL training loop. I found one topic relating to using pytorch_ema in Lightning in this discussion thread, but how would this work if I want to save a model checkpoint based on the EMA weights? For example, if I want to save the model weights using just PyTorch, I could do something like
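(a rough sketch of what I mean, assuming torch_ema's `ExponentialMovingAverage` and its `average_parameters()` context manager; the model, decay value, and filename are just placeholders:)

```python
import torch
from torch_ema import ExponentialMovingAverage

model = torch.nn.Linear(10, 2)  # placeholder model
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# ... training loop: call ema.update() after each optimizer step ...

# average_parameters() temporarily copies the EMA weights into the model
# and puts the original weights back when the block exits.
with ema.average_parameters():
    torch.save(model.state_dict(), "ema_weights.pt")  # placeholder filename
```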
so that I save the smoothed weights but restore the original weights to the model, so it doesn't affect training. One workaround I can think of is to create my own model saving logic in the …
Replies: 1 comment
-
You can replace the model `state_dict` inside the checkpoint:

```python
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    ...

    def on_save_checkpoint(self, checkpoint):
        # `ema` is a torch_ema ExponentialMovingAverage tracking the model's parameters;
        # the context manager temporarily loads the EMA weights into the model.
        with ema.average_parameters():
            checkpoint["state_dict"] = self.state_dict()
```
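For completeness, a minimal sketch of where such an `ema` object could live, here stored as an attribute (`self.ema`) and assuming torch_ema; the placeholder network, the decay value, and the choice of `on_before_zero_grad` as the update hook are my own assumptions rather than anything Lightning prescribes:

```python
import torch
from pytorch_lightning import LightningModule
from torch_ema import ExponentialMovingAverage


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 2)  # placeholder network
        # Track EMA copies of the parameters; 0.995 is an arbitrary example decay.
        self.ema = ExponentialMovingAverage(self.model.parameters(), decay=0.995)

    def on_before_zero_grad(self, optimizer):
        # Refresh the shadow weights after every optimizer step.
        self.ema.update()

    def on_save_checkpoint(self, checkpoint):
        # As above: write the smoothed weights into the checkpoint;
        # the original weights are restored when the block exits.
        with self.ema.average_parameters():
            checkpoint["state_dict"] = self.state_dict()
```

Note that the EMA shadow weights themselves aren't captured by `self.state_dict()` this way; if resuming with the EMA state matters, `self.ema.state_dict()` could also be stashed in the checkpoint from the same hook.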