133 changes: 133 additions & 0 deletions docs/source-pytorch/advanced/transfer_learning.rst
@@ -126,3 +126,136 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
h_cls = h[:, 0]
logits = self.W(h_cls)
return logits, attn

----

***********************************
Automated Finetuning with Callbacks
***********************************

PyTorch Lightning provides the :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback to automate
the finetuning process. The callback freezes your model's backbone at the start of training, then unfreezes it
at a chosen epoch and adds it to the optimizer with a reduced learning rate that grows over subsequent epochs.
This is particularly useful with large pretrained models: the randomly initialized head can stabilize first,
before the pretrained backbone weights are finetuned.

The :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback expects your model to have a specific structure:

.. testcode::

    import torch
    import torch.nn as nn
    from lightning.pytorch import LightningModule

    backbone_features, num_classes = 2048, 10
    # Stand-in for a real pretrained network (e.g. a torchvision or transformers model)
    some_pretrained_model = nn.Sequential(nn.Linear(backbone_features, backbone_features), nn.ReLU())


    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()

            # REQUIRED: your model must expose a ``backbone`` attribute.
            # This should be the pretrained part you want to finetune.
            self.backbone = some_pretrained_model

            # Your task-specific layers (head, classifier, etc.)
            self.head = nn.Linear(backbone_features, num_classes)

        def configure_optimizers(self):
            # Optimize only the head initially - the callback adds the backbone later
            return torch.optim.Adam(self.head.parameters(), lr=1e-3)
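
With that structure in place, attaching the callback is a single line. A minimal sketch
(``MyModel`` is the module defined above; supply your own dataloaders):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning

    # Keep ``model.backbone`` frozen for the first 5 epochs, then unfreeze it
    # and add it to the optimizer at a reduced learning rate
    trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=5)], max_epochs=20)
    # trainer.fit(MyModel(), train_dataloaders=...)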

************************************
Example: Computer Vision with ResNet
************************************

Here's a complete example showing how to use :class:`~lightning.pytorch.callbacks.BackboneFinetuning`
for computer vision:

.. testcode::
:skipif: not _TORCHVISION_AVAILABLE

import torch
import torch.nn as nn
import torchvision.models as models
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BackboneFinetuning


class ResNetClassifier(LightningModule):
def __init__(self, num_classes=10, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()

# Create backbone from pretrained ResNet
resnet = models.resnet50(weights="DEFAULT")
# Remove the final classification layer
self.backbone = nn.Sequential(*list(resnet.children())[:-1])

# Add custom classification head
self.head = nn.Sequential(
nn.Flatten(),
nn.Linear(resnet.fc.in_features, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, num_classes)
)

def forward(self, x):
# Extract features with backbone
features = self.backbone(x)
# Classify with head
return self.head(features)

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
return loss

def configure_optimizers(self):
# Initially only train the head - backbone will be added by callback
return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)


# Setup the finetuning callback
backbone_finetuning = BackboneFinetuning(
unfreeze_backbone_at_epoch=10, # Start unfreezing backbone at epoch 10
lambda_func=lambda epoch: 1.5, # Gradually increase backbone learning rate
backbone_initial_ratio_lr=0.1, # Backbone starts at 10% of head learning rate
should_align=True, # Align rates when backbone rate reaches head rate
verbose=True # Print learning rates during training
)

model = ResNetClassifier()
trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)

****************************
Custom Finetuning Strategies
****************************

For more control, you can create custom finetuning strategies by subclassing
:class:`~lightning.pytorch.callbacks.BaseFinetuning`:

.. testcode::

from lightning.pytorch.callbacks.finetuning import BaseFinetuning


class CustomFinetuning(BaseFinetuning):
def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
super().__init__()
self.unfreeze_at_epoch = unfreeze_at_epoch
self.layers_per_epoch = layers_per_epoch

def freeze_before_training(self, pl_module):
# Freeze the entire backbone initially
self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, epoch, optimizer):
            # Unfreeze ``layers_per_epoch`` additional layers each epoch, top down
            if epoch >= self.unfreeze_at_epoch:
                children = list(pl_module.backbone.children())

                # Layers thawed in previous epochs, and the target after this epoch
                thawed = min((epoch - self.unfreeze_at_epoch) * self.layers_per_epoch, len(children))
                target = min(thawed + self.layers_per_epoch, len(children))
                if target == thawed:
                    return  # the entire backbone is already trainable

                # Add only the newly thawed layers as a new parameter group
                for layer in children[len(children) - target : len(children) - thawed]:
                    self.unfreeze_and_add_param_group(layer, optimizer, lr=1e-4)
40 changes: 38 additions & 2 deletions src/lightning/pytorch/callbacks/finetuning.py
@@ -354,10 +354,46 @@ class BackboneFinetuning(BaseFinetuning):

Example::

>>> import torch
>>> import torch.nn as nn
>>> from lightning.pytorch import LightningModule, Trainer
>>> from lightning.pytorch.callbacks import BackboneFinetuning
>>> import torchvision.models as models
>>>
>>> class TransferLearningModel(LightningModule):
... def __init__(self, num_classes=10):
... super().__init__()
... # REQUIRED: Your model must have a 'backbone' attribute
... self.backbone = models.resnet50(weights="DEFAULT")
... # Remove the final classification layer from backbone
... self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
...
... # Add your task-specific head
... self.head = nn.Sequential(
... nn.Flatten(),
... nn.Linear(2048, 512),
... nn.ReLU(),
... nn.Linear(512, num_classes)
... )
...
... def forward(self, x):
... # Extract features with backbone
... features = self.backbone(x)
... # Classify with head
... return self.head(features)
...
... def configure_optimizers(self):
... # Initially only optimize the head - backbone will be added by callback
... return torch.optim.Adam(self.head.parameters(), lr=1e-3)
...
>>> # Setup the callback
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(
... unfreeze_backbone_at_epoch=10, # Start unfreezing at epoch 10
... lambda_func=multiplicative, # Gradually increase backbone LR
... backbone_initial_ratio_lr=0.1, # Start backbone at 10% of head LR
... )
>>> model = TransferLearningModel()
>>> trainer = Trainer(callbacks=[backbone_finetuning])

"""