Inheritance and save_hyperparameters
#9509
Hello Lightning folks! Suppose I have a base model class that I'd like to inherit from as follows:

```python
import pytorch_lightning as pl


class ParentModel(pl.LightningModule):
    def __init__(
        self,
        lr: float = 0.001,
    ):
        super().__init__()
        self.lr = lr


class ChildModel(ParentModel):
    def __init__(
        self,
        lr: float = 0.005,
        loss: str = "mse",
    ):
        # Run ParentModel.__init__ (which sets self.lr); super(ParentModel, self)
        # would skip it and go straight to pl.LightningModule.__init__.
        super().__init__(lr=lr)
        self.loss = loss
```

I would like to be able to access the hyperparameters of […]

```python
class ChildModel(ParentModel):
    def __init__(
        self,
        lr: float = 0.005,
        loss: str = "mse",
    ):
        super().__init__(lr=lr)
        self.loss = loss
        self.save_hyperparameters()
```

However, I would like to avoid the need to call […] One idea I have in mind is something like a […]

Thanks!
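A side note on calling `super()` with an explicit class, as in `super(ParentModel, self).__init__()` inside a `ChildModel`: an explicit class argument starts the method lookup *after* that class in the MRO, so inside a subclass of `ParentModel` it skips `ParentModel.__init__` entirely. This plain-Python sketch needs no Lightning. `Base`, `Parent`, `ChildExplicit`, and `ChildZeroArg` are hypothetical stand-ins, with `Base` playing the role of `pl.LightningModule`.

```python
class Base:
    # Hypothetical stand-in for pl.LightningModule.
    def __init__(self):
        self.calls = ["Base"]


class Parent(Base):
    def __init__(self, lr: float = 0.001):
        super().__init__()
        self.calls.append("Parent")
        self.lr = lr


class ChildExplicit(Parent):
    def __init__(self, lr: float = 0.005):
        # Lookup starts *after* Parent in the MRO, so Parent.__init__ is
        # skipped and Base.__init__ runs directly; self.lr is never set.
        super(Parent, self).__init__()


class ChildZeroArg(Parent):
    def __init__(self, lr: float = 0.005):
        # Zero-argument super() starts after ChildZeroArg, i.e. at Parent.
        super().__init__(lr=lr)


print(ChildExplicit().calls)           # ['Base']
print(hasattr(ChildExplicit(), "lr"))  # False
print(ChildZeroArg().calls)            # ['Base', 'Parent']
print(ChildZeroArg().lr)               # 0.005
```

In Python 3 the zero-argument form is usually the safer default, since it always continues from the class the method is defined in.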
You could do something like this:

```python
import pytorch_lightning as pl


class ParentModel(pl.LightningModule):
    def __init__(
        self,
        lr: float = 0.001,
        **kwargs,
    ):
        super().__init__()
        # Saves lr plus every extra keyword argument forwarded by subclasses.
        self.save_hyperparameters()
        self.lr = lr


class ChildModel(ParentModel):
    def __init__(
        self,
        lr: float = 0.005,
        loss: str = "mse",
    ):
        # Zero-argument super() runs ParentModel.__init__; super(ParentModel, self)
        # would skip it and pass lr/loss to pl.LightningModule.__init__ instead.
        super().__init__(lr=lr, loss=loss)
        self.loss = loss
```

That would save all hparams passed to the parent model (including the ones passed through the kwargs). If you want to go one step further, you could also include the following in `ParentModel.__init__`:

```python
for k, v in kwargs.items():
    setattr(self, k, v)
```

which automatically sets every argument passed through `kwargs` as a model attribute.
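To see the kwargs-forwarding pattern in isolation, here is a minimal Lightning-free sketch. It assumes `ParentModel` is a plain Python class rather than a `pl.LightningModule` and fills a plain `hparams` dict by hand, so it only approximates what the answer describes `save_hyperparameters()` doing; it is not Lightning's implementation.

```python
from typing import Any


class ParentModel:
    # Plain-Python stand-in for a pl.LightningModule subclass (hypothetical).
    def __init__(self, lr: float = 0.001, **kwargs: Any) -> None:
        # Hand-rolled record of the named argument plus every forwarded kwarg;
        # an approximation of what the answer says save_hyperparameters() keeps.
        self.hparams = {"lr": lr, **kwargs}
        self.lr = lr
        # The optional "one step further": expose each extra kwarg as an attribute.
        for k, v in kwargs.items():
            setattr(self, k, v)


class ChildModel(ParentModel):
    def __init__(self, lr: float = 0.005, loss: str = "mse") -> None:
        # Forward child-specific arguments to the parent as keyword arguments;
        # the parent's setattr loop sets self.loss, so no extra assignment is needed.
        super().__init__(lr=lr, loss=loss)


model = ChildModel()
print(model.hparams)  # {'lr': 0.005, 'loss': 'mse'}
print(model.loss)     # mse
```

One trade-off of the `setattr` loop: a misspelled keyword argument is silently accepted and becomes a new attribute instead of raising a `TypeError`.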