Skip to content

Commit 7845b5d

Browse files
committed
add to documentation
1 parent b7ca4d3 commit 7845b5d

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

docs/source-pytorch/advanced/transfer_learning.rst

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,136 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
126126
h_cls = h[:, 0]
127127
logits = self.W(h_cls)
128128
return logits, attn
129+
130+
----
131+
132+
***********************************
133+
Automated Finetuning with Callbacks
134+
***********************************
135+
136+
PyTorch Lightning provides the :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback to automate
137+
the finetuning process. This callback gradually unfreezes your model's backbone during training. This is particularly
138+
useful when working with large pretrained models, as it allows you to start training with a frozen backbone and
139+
then progressively unfreeze layers to fine-tune the model.
140+
141+
The :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback expects your model to have a specific structure:
142+
143+
.. testcode::
144+
145+
class MyModel(LightningModule):
146+
def __init__(self):
147+
super().__init__()
148+
149+
# REQUIRED: Your model must have a 'backbone' attribute
150+
# This should be the pretrained part you want to finetune
151+
self.backbone = some_pretrained_model
152+
153+
# Your task-specific layers (head, classifier, etc.)
154+
self.head = nn.Linear(backbone_features, num_classes)
155+
156+
def configure_optimizers(self):
157+
# Only optimize the head initially - backbone will be added automatically
158+
return torch.optim.Adam(self.head.parameters(), lr=1e-3)
159+
160+
************************************
161+
Example: Computer Vision with ResNet
162+
************************************
163+
164+
Here's a complete example showing how to use :class:`~lightning.pytorch.callbacks.BackboneFinetuning`
165+
for computer vision:
166+
167+
.. testcode::
168+
:skipif: not _TORCHVISION_AVAILABLE
169+
170+
import torch
171+
import torch.nn as nn
172+
import torchvision.models as models
173+
from lightning.pytorch import LightningModule, Trainer
174+
from lightning.pytorch.callbacks import BackboneFinetuning
175+
176+
177+
class ResNetClassifier(LightningModule):
178+
def __init__(self, num_classes=10, learning_rate=1e-3):
179+
super().__init__()
180+
self.save_hyperparameters()
181+
182+
# Create backbone from pretrained ResNet
183+
resnet = models.resnet50(weights="DEFAULT")
184+
# Remove the final classification layer
185+
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
186+
187+
# Add custom classification head
188+
self.head = nn.Sequential(
189+
nn.Flatten(),
190+
nn.Linear(resnet.fc.in_features, 512),
191+
nn.ReLU(),
192+
nn.Dropout(0.2),
193+
nn.Linear(512, num_classes)
194+
)
195+
196+
def forward(self, x):
197+
# Extract features with backbone
198+
features = self.backbone(x)
199+
# Classify with head
200+
return self.head(features)
201+
202+
def training_step(self, batch, batch_idx):
203+
x, y = batch
204+
y_hat = self(x)
205+
loss = nn.functional.cross_entropy(y_hat, y)
206+
self.log('train_loss', loss)
207+
return loss
208+
209+
def configure_optimizers(self):
210+
# Initially only train the head - backbone will be added by callback
211+
return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)
212+
213+
214+
# Setup the finetuning callback
215+
backbone_finetuning = BackboneFinetuning(
216+
unfreeze_backbone_at_epoch=10, # Start unfreezing backbone at epoch 10
217+
lambda_func=lambda epoch: 1.5, # Gradually increase backbone learning rate
218+
backbone_initial_ratio_lr=0.1, # Backbone starts at 10% of head learning rate
219+
should_align=True, # Align rates when backbone rate reaches head rate
220+
verbose=True # Print learning rates during training
221+
)
222+
223+
model = ResNetClassifier()
224+
trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
225+
226+
****************************
227+
Custom Finetuning Strategies
228+
****************************
229+
230+
For more control, you can create custom finetuning strategies by subclassing
231+
:class:`~lightning.pytorch.callbacks.BaseFinetuning`:
232+
233+
.. testcode::
234+
235+
from lightning.pytorch.callbacks.finetuning import BaseFinetuning
236+
237+
238+
class CustomFinetuning(BaseFinetuning):
239+
def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
240+
super().__init__()
241+
self.unfreeze_at_epoch = unfreeze_at_epoch
242+
self.layers_per_epoch = layers_per_epoch
243+
244+
def freeze_before_training(self, pl_module):
245+
# Freeze the entire backbone initially
246+
self.freeze(pl_module.backbone)
247+
248+
def finetune_function(self, pl_module, epoch, optimizer):
249+
# Gradually unfreeze layers
250+
if epoch >= self.unfreeze_at_epoch:
251+
layers_to_unfreeze = min(
252+
self.layers_per_epoch,
253+
len(list(pl_module.backbone.children()))
254+
)
255+
256+
# Unfreeze from the top layers down
257+
backbone_children = list(pl_module.backbone.children())
258+
for layer in backbone_children[-layers_to_unfreeze:]:
259+
self.unfreeze_and_add_param_group(
260+
layer, optimizer, lr=1e-4
261+
)

0 commit comments

Comments
 (0)