You could do something like this:

import pytorch_lightning as pl


class ParentModel(pl.LightningModule):
    def __init__(
        self,
        lr: float = 0.001,
        **kwargs,
    ):
        super().__init__()
        # save_hyperparameters() captures every argument passed to __init__,
        # including anything forwarded by subclasses through **kwargs.
        self.save_hyperparameters()
        self.lr = lr


class ChildModel(ParentModel):
    def __init__(
        self,
        lr: float = 0.005,
        loss: str = "mse",
    ):
        # Use super() (not super(ParentModel, self), which would skip
        # ParentModel.__init__ entirely); `loss` is forwarded through
        # the parent's **kwargs and saved along with the rest.
        super().__init__(lr=lr, loss=loss)
        self.loss = loss

That saves all hyperparameters passed to the parent model's __init__, including the ones forwarded through **kwargs. If you want to go one step further, you could also include the following there:

for k, v in kw…
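Presumably that loop iterates over the keyword arguments and attaches each one as an attribute, so subclasses don't have to assign every hyperparameter by hand. A minimal sketch of that pattern, shown in plain Python (without Lightning) to keep the mechanics visible; the class names here are illustrative only:

```python
class ParentModel:
    def __init__(self, lr: float = 0.001, **kwargs):
        self.lr = lr
        # Attach every forwarded keyword argument as an attribute,
        # e.g. loss="mse" becomes self.loss == "mse".
        for k, v in kwargs.items():
            setattr(self, k, v)


class ChildModel(ParentModel):
    def __init__(self, lr: float = 0.005, loss: str = "mse"):
        super().__init__(lr=lr, loss=loss)


model = ChildModel()
print(model.lr, model.loss)  # → 0.005 mse
```

In a LightningModule the same loop would sit next to the `self.save_hyperparameters()` call in the parent's __init__.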

Answer selected by lewtun