diff --git a/docs/source-pytorch/advanced/transfer_learning.rst b/docs/source-pytorch/advanced/transfer_learning.rst index 7f6af6ad5a56d..50a65870b1572 100644 --- a/docs/source-pytorch/advanced/transfer_learning.rst +++ b/docs/source-pytorch/advanced/transfer_learning.rst @@ -32,7 +32,7 @@ Let's use the `AutoEncoder` as a feature extractor in a separate model. class CIFAR10Classifier(LightningModule): def __init__(self): # init the pretrained LightningModule - self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH) + self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH).encoder self.feature_extractor.freeze() # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes