            h_cls = h[:, 0]
            logits = self.W(h_cls)
            return logits, attn

----

***********************************
Automated Finetuning with Callbacks
***********************************

PyTorch Lightning provides the :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback to automate
the finetuning process. It keeps your model's backbone frozen at the start of training, then unfreezes it at a
chosen epoch and ramps its learning rate up gradually. This is particularly useful when working with large
pretrained models, as it lets the task-specific head stabilize before the pretrained weights are updated.
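
In the simplest case, you construct the callback and pass it to the ``Trainer``; a minimal sketch (the epoch
threshold here is an arbitrary choice):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning

    # keep the backbone frozen for the first 10 epochs, then unfreeze it
    trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])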

The :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback expects your model to have a specific structure:

.. code-block:: python

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()

            # REQUIRED: your model must have a 'backbone' attribute;
            # this should be the pretrained part you want to finetune
            self.backbone = some_pretrained_model

            # Your task-specific layers (head, classifier, etc.)
            self.head = nn.Linear(backbone_features, num_classes)

        def configure_optimizers(self):
            # Optimize only the head at first; the callback adds the backbone
            # parameters to the optimizer when it unfreezes them
            return torch.optim.Adam(self.head.parameters(), lr=1e-3)
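
Under the hood, the callback uses the :class:`~lightning.pytorch.callbacks.BaseFinetuning` utilities shown
later on this page: it freezes ``model.backbone`` before training starts and, at the configured epoch,
unfreezes it and adds its parameters to the optimizer as a new parameter group. The equivalent manual steps
look roughly like this (an illustrative sketch only; the learning rate is arbitrary and the callback performs
these steps for you):

.. code-block:: python

    from lightning.pytorch.callbacks.finetuning import BaseFinetuning

    model = MyModel()
    optimizer = model.configure_optimizers()  # head parameters only

    # before training: freeze the backbone
    BaseFinetuning.freeze(model.backbone)

    # at the unfreeze epoch: make the backbone trainable and
    # give it its own parameter group in the optimizer
    BaseFinetuning.unfreeze_and_add_param_group(model.backbone, optimizer, lr=1e-4)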

************************************
Example: Computer Vision with ResNet
************************************

Here's a complete example showing how to use :class:`~lightning.pytorch.callbacks.BackboneFinetuning`
for computer vision:

.. testcode::
    :skipif: not _TORCHVISION_AVAILABLE

    import torch
    import torch.nn as nn
    import torchvision.models as models
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import BackboneFinetuning


    class ResNetClassifier(LightningModule):
        def __init__(self, num_classes=10, learning_rate=1e-3):
            super().__init__()
            self.save_hyperparameters()

            # Create the backbone from a pretrained ResNet
            resnet = models.resnet50(weights="DEFAULT")
            # Remove the final classification layer
            self.backbone = nn.Sequential(*list(resnet.children())[:-1])

            # Add a custom classification head
            self.head = nn.Sequential(
                nn.Flatten(),
                nn.Linear(resnet.fc.in_features, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, num_classes),
            )

        def forward(self, x):
            # Extract features with the backbone
            features = self.backbone(x)
            # Classify with the head
            return self.head(features)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = nn.functional.cross_entropy(y_hat, y)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            # Initially only train the head; the callback adds the backbone later
            return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)


    # Set up the finetuning callback
    backbone_finetuning = BackboneFinetuning(
        unfreeze_backbone_at_epoch=10,  # unfreeze the backbone at epoch 10
        lambda_func=lambda epoch: 1.5,  # multiply the backbone learning rate by 1.5 each epoch
        backbone_initial_ratio_lr=0.1,  # backbone starts at 10% of the head learning rate
        should_align=True,  # cap the backbone rate once it reaches the head rate
        verbose=True,  # print the learning rates during training
    )

    model = ResNetClassifier()
    trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
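
With these settings and a head learning rate of ``1e-3``, the backbone's parameters join the optimizer at
epoch 10 with a learning rate of ``1e-4`` (10% of the head rate), which is then multiplied by 1.5 each epoch
until it aligns with the head rate. A back-of-the-envelope sketch of that schedule:

.. code-block:: python

    head_lr = 1e-3
    backbone_lr = head_lr * 0.1  # backbone_initial_ratio_lr=0.1
    for epoch in range(10, 17):
        print(f"epoch {epoch}: backbone lr = {backbone_lr:.2e}")
        backbone_lr = min(backbone_lr * 1.5, head_lr)  # should_align=True caps at the head rate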

****************************
Custom Finetuning Strategies
****************************

For more control, you can create custom finetuning strategies by subclassing
:class:`~lightning.pytorch.callbacks.BaseFinetuning`:

.. testcode::

    from lightning.pytorch.callbacks.finetuning import BaseFinetuning


    class CustomFinetuning(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
            super().__init__()
            self.unfreeze_at_epoch = unfreeze_at_epoch
            self.layers_per_epoch = layers_per_epoch

        def freeze_before_training(self, pl_module):
            # Freeze the entire backbone initially
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, epoch, optimizer):
            # Once the threshold epoch is reached, unfreeze a few more layers
            # every epoch, working from the top of the backbone down
            if epoch >= self.unfreeze_at_epoch:
                backbone_children = list(pl_module.backbone.children())
                num_children = len(backbone_children)

                # Slice of layers to thaw this epoch (children are ordered
                # bottom-up, so unfreezing top-down means slicing from the end)
                start = (epoch - self.unfreeze_at_epoch) * self.layers_per_epoch
                stop = min(start + self.layers_per_epoch, num_children)
                if start >= num_children:
                    return  # the whole backbone is already unfrozen

                # Unfreeze this epoch's layers and add them to the optimizer
                # as a new parameter group
                self.unfreeze_and_add_param_group(
                    backbone_children[num_children - stop : num_children - start],
                    optimizer,
                    lr=1e-4,
                )
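
As with the built-in callback, you enable a custom strategy by passing an instance to the ``Trainer``.
For example, reusing the ``ResNetClassifier`` defined above:

.. code-block:: python

    model = ResNetClassifier()
    trainer = Trainer(
        callbacks=[CustomFinetuning(unfreeze_at_epoch=5, layers_per_epoch=2)],
        max_epochs=20,
    )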