Saving hyperparameters vs checkpointing #15864
Unanswered
mfoglio
asked this question in
Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 0 comments
I am building an image classifier with multiple heads. Each head predicts a categorical value.

Let's assume that we want to classify fruits and their color. We will have one head that predicts the name of the fruit (`apple`, `pear`, `grapes`) and another head that predicts the color (`red`, `green`, `yellow`).

The labels of both attributes (fruit name and fruit color) are stored in a configuration file. At runtime, a `MultiOutputClassifier` (see below) is initialized within `MultiOutputClassifierModule`, reflecting the number of attributes and their labels. In other words, the output of the network is not hardcoded, but defined according to a configuration file. The class `MultiOutputClassifier` (not an instance) is passed to the `LightningModule` during its initialization:
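(Simplified sketch of the two classes; the placeholder backbone, the `labels` argument, and the loss below are just for illustration, not my real code.)

```python
from typing import Dict, List, Type

import torch
from torch import nn
import pytorch_lightning as pl


class MultiOutputClassifier(nn.Module):
    """One backbone plus one classification head per attribute.

    `labels` comes from the config file, e.g.
    {"fruit": ["apple", "pear", "grapes"], "color": ["red", "green", "yellow"]}.
    """

    def __init__(self, labels: Dict[str, List[str]]):
        super().__init__()
        # Placeholder backbone; the real architecture is irrelevant to the question.
        self.backbone = nn.Sequential(nn.Flatten(), nn.LazyLinear(128), nn.ReLU())
        # One linear head per attribute, sized by the number of labels for that attribute.
        self.heads = nn.ModuleDict(
            {attr: nn.Linear(128, len(classes)) for attr, classes in labels.items()}
        )

    def forward(self, x):
        features = self.backbone(x)
        return {attr: head(features) for attr, head in self.heads.items()}


class MultiOutputClassifierModule(pl.LightningModule):
    def __init__(self, net: Type[nn.Module], labels: Dict[str, List[str]], lr: float = 1e-3):
        super().__init__()
        # `net` is the *class*; the instance is built here from the config-driven labels.
        self.save_hyperparameters(logger=False)
        self.net = net(labels)
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, targets = batch  # targets is a dict, e.g. {"fruit": ..., "color": ...}
        outputs = self.net(x)
        loss = sum(self.criterion(outputs[attr], targets[attr]) for attr in outputs)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```

The important part is that `MultiOutputClassifierModule` receives the class as the `net` argument and builds the instance `self.net` itself from the configuration.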
I am doing this to speed up experimentation: I can define a different model architecture `MultiOutputClassifierV2` and pass it to the same `MultiOutputClassifierModule`. In brief, this structure allows me to try multiple models by simply passing them to the Lightning module.

Everything works great, but I have trouble understanding which hyperparameters I should save.
If I simply use `self.save_hyperparameters(logger=False)`, everything seems to work great: the checkpoint is saved, and I can load it with `MultiOutputClassifierModule.load_from_checkpoint(checkpoint_file_path)`; however, I get a warning (see question 3 below).
However, if I use `self.save_hyperparameters(ignore=['net'])`, then I cannot call `MultiOutputClassifierModule.load_from_checkpoint(checkpoint_file_path)` as-is, because it will expect the parameter `net` as well. It doesn't make sense to me to pass `net` separately, since a specific network was already initialized during the previous training based on a specific configuration file.
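To make the two situations concrete, this is roughly what loading looks like in each case (the checkpoint path is a placeholder):

```python
# Case 1: training used self.save_hyperparameters(logger=False).
# The checkpoint stores every __init__ argument (including the class `net`),
# so loading needs nothing else:
model = MultiOutputClassifierModule.load_from_checkpoint("last.ckpt")

# Case 2: training used self.save_hyperparameters(ignore=['net']).
# `net` is not among the saved hyperparameters, so it must be passed again:
model = MultiOutputClassifierModule.load_from_checkpoint(
    "last.ckpt",
    net=MultiOutputClassifier,  # has to be the same class used during training
)
```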
My questions are:

1. What does `self.save_hyperparameters()` do exactly? Which Python attributes does it save, and are they saved just at initialization or every time a checkpoint is written? I am assuming it saves the kwargs received by the Lightning module; let me know if this is correct.
2. `self.save_hyperparameters()` seems to work better for me, but I don't understand why Lightning suggests the other option (i.e. `self.save_hyperparameters(ignore=['net'])`).
3. The warning says that an instance of `nn.Module` "is already saved during checkpointing". If I understand correctly, when doing `self.save_hyperparameters()` I am saving the initial (untrained) `net` as a hyperparameter, and then saving its trained weights again inside the `state_dict`. Is this correct?

Thank you
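For reference, this is how I have been inspecting what actually ends up in the checkpoint (standard Lightning checkpoint keys; the path is a placeholder):

```python
import torch

# On recent PyTorch you may need weights_only=False here, because the saved
# hyperparameters contain a pickled class (`net`), not just tensors.
ckpt = torch.load("last.ckpt", map_location="cpu")

# Everything captured by self.save_hyperparameters() ends up under this key.
print(ckpt["hyper_parameters"])

# The trained weights (including self.net's parameters) live in the state_dict,
# regardless of what save_hyperparameters() recorded.
print(list(ckpt["state_dict"])[:5])
```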
EDIT: is it possible that my problem arises from the fact that I have a kwarg `net` which should be a class, and then an attribute called `net` (same name) which is instead an instance of that class? I am wondering if this could be the reason Lightning raises the warning.