SWA callback with torch.save state_dict() #11997
Unanswered
maxmatical asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 1 comment
I'm trying to experiment with the SWA callback when training PyTorch models using the Lightning Trainer. One problem I'm running into stems from how my LightningModule is structured (see the sketch below).
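Based on the title and the workflow described (torch.save of a state_dict, then evaluating the plain PyTorch model), here is a minimal sketch of this kind of setup; LitModel, the layer sizes, the synthetic data, and the model.pt file name are hypothetical placeholders rather than the original code:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging
from torch.utils.data import DataLoader, TensorDataset

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # inner plain-torch model; this is what torch.save exports below
        self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# synthetic stand-in data so the sketch runs end to end
train_dl = DataLoader(
    TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,))),
    batch_size=32,
)

lit = LitModel()
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
)
trainer.fit(lit, train_dataloaders=train_dl)

# export only the inner torch model for evaluation outside Lightning
torch.save(lit.model.state_dict(), "model.pt")
```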
I trained with and without the SWA callback (using the same seed, for 10 epochs, so I know it is actually averaging models). But after loading the saved PyTorch model and running evaluation, I get the same metric whether training ran with or without SWA.
Does anyone know why this is the case?
Edit: solved the issue. It turns out the trainer was automatically loading the best model checkpoint, which doesn't contain the averaged weights.
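Concretely: Lightning's StochasticWeightAveraging callback copies the averaged weights back into the live LightningModule at the end of training, so the fix is to save the state_dict straight from the in-memory module after trainer.fit() returns, instead of reloading a "best" checkpoint that was written before that swap. A sketch continuing the placeholder names from above (swa_model.pt is also hypothetical):

```python
# After trainer.fit(...) returns, the live module already holds the
# SWA-averaged weights, so save from it directly rather than reloading
# a "best" checkpoint that predates the weight swap.
torch.save(lit.model.state_dict(), "swa_model.pt")

# Evaluate in plain PyTorch, outside Lightning.
eval_model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))
eval_model.load_state_dict(torch.load("swa_model.pt"))
eval_model.eval()
```

If you instead load the best checkpoint (for example via trainer.test(ckpt_path="best"), or by reloading the checkpoint file manually), you get the pre-SWA weights, which would explain identical metrics with and without the callback.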
Reply: Could you please show how you were able to retrieve the model with the averaged weights, if you were able to? I can't seem to figure that out.