Is there a way to save only part of the Lightning sub-modules to the checkpoint file? #10808
-
I'll explain: say I have two nn.Modules inside my main LightningModule, but one of them is frozen, i.e. it does not learn during training and is only used for inference (requires_grad is False for all of its parameters). I would like to avoid saving that static (frozen) module's state_dict to the checkpoint file. In plain PyTorch I would probably filter the frozen module's entries out of the state_dict manually before saving. A simple toy example for clarification is sketched below.
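For clarification, here is a minimal sketch of the plain-PyTorch filtering described above; the attribute names trainable_head and frozen_backbone are made up for illustration:

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.trainable_head = nn.Linear(16, 4)    # learns during training
        self.frozen_backbone = nn.Linear(16, 16)  # frozen, inference only
        for p in self.frozen_backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.trainable_head(self.frozen_backbone(x))


model = ToyModel()

# Drop every entry that belongs to the frozen sub-module before saving.
filtered = {k: v for k, v in model.state_dict().items()
            if not k.startswith("frozen_backbone.")}
torch.save(filtered, "checkpoint.pt")
```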
Replies: 1 comment
-
You have to do that here too. Within Lightning you can override the on_save_checkpoint hook of LightningModule:

```python
def on_save_checkpoint(self, checkpoint):
    checkpoint['state_dict']  # <- remove/pop keys from here
```
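A minimal sketch of that hook inside a LightningModule, assuming the frozen sub-module lives under the (hypothetical) attribute name frozen_backbone:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.trainable_head = nn.Linear(16, 4)    # learns during training
        self.frozen_backbone = nn.Linear(16, 16)  # frozen, inference only
        for p in self.frozen_backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.trainable_head(self.frozen_backbone(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.trainable_head.parameters())

    def on_save_checkpoint(self, checkpoint):
        # Remove the frozen module's weights so they are not written
        # to the checkpoint file.
        for key in list(checkpoint["state_dict"]):
            if key.startswith("frozen_backbone."):
                checkpoint["state_dict"].pop(key)
```

Note that a checkpoint saved this way is missing the frozen module's keys, so when loading it you would likely need strict=False (e.g. MyModel.load_from_checkpoint(path, strict=False)) and restore the frozen weights from their original source.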