diff --git a/docs/source-pytorch/advanced/transfer_learning.rst b/docs/source-pytorch/advanced/transfer_learning.rst
index 50a65870b1572..94df157096549 100644
--- a/docs/source-pytorch/advanced/transfer_learning.rst
+++ b/docs/source-pytorch/advanced/transfer_learning.rst
@@ -126,3 +126,135 @@ Here's a model that uses `Huggingface transformers
+            if current_epoch >= self.unfreeze_at_epoch:
+                layers_to_unfreeze = min(
+                    self.layers_per_epoch,
+                    len(list(pl_module.backbone.children()))
+                )
+
+                # Unfreeze from the top layers down
+                backbone_children = list(pl_module.backbone.children())
+                for layer in backbone_children[-layers_to_unfreeze:]:
+                    self.unfreeze_and_add_param_group(
+                        layer, optimizer, lr=1e-4
+                    )
diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py
index cec83fee0f4d7..da24387a82a8d 100644
--- a/src/lightning/pytorch/callbacks/finetuning.py
+++ b/src/lightning/pytorch/callbacks/finetuning.py
@@ -31,8 +31,13 @@
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
+if not _TORCHVISION_AVAILABLE:
+    __doctest_skip__ = ["BackboneFinetuning"]
+
+
 log = logging.getLogger(__name__)
@@ -354,10 +359,46 @@ class BackboneFinetuning(BaseFinetuning):
 
     Example::
 
-        >>> from lightning.pytorch import Trainer
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> from lightning.pytorch import LightningModule, Trainer
         >>> from lightning.pytorch.callbacks import BackboneFinetuning
+        >>> import torchvision.models as models
+        >>>
+        >>> class TransferLearningModel(LightningModule):
+        ...     def __init__(self, num_classes=10):
+        ...         super().__init__()
+        ...         # REQUIRED: Your model must have a 'backbone' attribute
+        ...         self.backbone = models.resnet50(weights="DEFAULT")
+        ...         # Remove the final classification layer from backbone
+        ...         self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
+        ...
+        ...         # Add your task-specific head
+        ...         self.head = nn.Sequential(
+        ...             nn.Flatten(),
+        ...             nn.Linear(2048, 512),
+        ...             nn.ReLU(),
+        ...             nn.Linear(512, num_classes)
+        ...         )
+        ...
+        ...     def forward(self, x):
+        ...         # Extract features with backbone
+        ...         features = self.backbone(x)
+        ...         # Classify with head
+        ...         return self.head(features)
+        ...
+        ...     def configure_optimizers(self):
+        ...         # Initially only optimize the head - backbone will be added by callback
+        ...         return torch.optim.Adam(self.head.parameters(), lr=1e-3)
+        ...
+        >>> # Setup the callback
         >>> multiplicative = lambda epoch: 1.5
-        >>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
+        >>> backbone_finetuning = BackboneFinetuning(
+        ...     unfreeze_backbone_at_epoch=10,  # Start unfreezing at epoch 10
+        ...     lambda_func=multiplicative,  # Gradually increase backbone LR
+        ...     backbone_initial_ratio_lr=0.1,  # Start backbone at 10% of head LR
+        ... )
+        >>> model = TransferLearningModel()
         >>> trainer = Trainer(callbacks=[backbone_finetuning])
 
     """
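
Only the tail of a ``finetune_function`` body survives in the first hunk above. For context, here is a minimal sketch of the kind of custom ``BaseFinetuning`` callback that fragment appears to belong to; the class name, the constructor defaults, and the ``freeze_before_training`` body are assumptions, and only the unfreezing logic is taken from the diff::

    from lightning.pytorch.callbacks import BaseFinetuning


    class GradualBackboneUnfreeze(BaseFinetuning):  # hypothetical name
        def __init__(self, unfreeze_at_epoch=10, layers_per_epoch=2):  # assumed defaults
            super().__init__()
            self.unfreeze_at_epoch = unfreeze_at_epoch
            self.layers_per_epoch = layers_per_epoch

        def freeze_before_training(self, pl_module):
            # Assumed setup: freeze the whole backbone before training starts
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, current_epoch, optimizer):
            # Unfreezing logic as shown in the docs fragment above
            if current_epoch >= self.unfreeze_at_epoch:
                layers_to_unfreeze = min(
                    self.layers_per_epoch,
                    len(list(pl_module.backbone.children())),
                )
                # Unfreeze from the top layers down; params already in the
                # optimizer are filtered out by unfreeze_and_add_param_group
                backbone_children = list(pl_module.backbone.children())
                for layer in backbone_children[-layers_to_unfreeze:]:
                    self.unfreeze_and_add_param_group(layer, optimizer, lr=1e-4)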
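
The reworked docstring example needs torchvision and a full ResNet-50, which is heavy for a quick check. As a lighter end-to-end sanity check of the same callback configuration, here is a sketch with a tiny stand-in backbone and random data; everything except the ``BackboneFinetuning`` arguments is a stand-in, including the ``training_step`` that the doctest omits::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning


    class TinyTransferModel(LightningModule):
        def __init__(self, num_classes=10):
            super().__init__()
            # BackboneFinetuning requires a 'backbone' submodule
            self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
            self.head = nn.Linear(64, num_classes)

        def forward(self, x):
            return self.head(self.backbone(x))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self(x), y)

        def configure_optimizers(self):
            # Only the head at first; the callback adds the backbone group later
            return torch.optim.Adam(self.head.parameters(), lr=1e-3)


    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 10, (64,)))
    backbone_finetuning = BackboneFinetuning(
        unfreeze_backbone_at_epoch=2,   # unfreeze early so a short run exercises it
        lambda_func=lambda epoch: 1.5,  # gradually increase backbone LR
        backbone_initial_ratio_lr=0.1,  # backbone starts at 10% of the head LR
    )
    trainer = Trainer(
        max_epochs=3,
        callbacks=[backbone_finetuning],
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(TinyTransferModel(), DataLoader(dataset, batch_size=16))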