@@ -354,10 +354,46 @@ class BackboneFinetuning(BaseFinetuning):
354
354
355
355
Example::
356
356
357
- >>> from lightning.pytorch import Trainer
357
+ >>> import torch
358
+ >>> import torch.nn as nn
359
+ >>> from lightning.pytorch import LightningModule, Trainer
358
360
>>> from lightning.pytorch.callbacks import BackboneFinetuning
361
+ >>> import torchvision.models as models
362
+ >>>
363
+ >>> class TransferLearningModel(LightningModule):
364
+ ... def __init__(self, num_classes=10):
365
+ ... super().__init__()
366
+ ... # REQUIRED: Your model must have a 'backbone' attribute
367
+ ... self.backbone = models.resnet50(weights="DEFAULT")
368
+ ... # Remove the final classification layer from backbone
369
+ ... self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
370
+ ...
371
+ ... # Add your task-specific head
372
+ ... self.head = nn.Sequential(
373
+ ... nn.Flatten(),
374
+ ... nn.Linear(2048, 512),
375
+ ... nn.ReLU(),
376
+ ... nn.Linear(512, num_classes)
377
+ ... )
378
+ ...
379
+ ... def forward(self, x):
380
+ ... # Extract features with backbone
381
+ ... features = self.backbone(x)
382
+ ... # Classify with head
383
+ ... return self.head(features)
384
+ ...
385
+ ... def configure_optimizers(self):
386
+ ... # Initially only optimize the head - backbone will be added by callback
387
+ ... return torch.optim.Adam(self.head.parameters(), lr=1e-3)
388
+ ...
389
+ >>> # Setup the callback
359
390
>>> multiplicative = lambda epoch: 1.5
360
- >>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
391
+ >>> backbone_finetuning = BackboneFinetuning(
392
+ ... unfreeze_backbone_at_epoch=10, # Start unfreezing at epoch 10
393
+ ... lambda_func=multiplicative, # Gradually increase backbone LR
394
+ ... backbone_initial_ratio_lr=0.1, # Start backbone at 10% of head LR
395
+ ... )
396
+ >>> model = TransferLearningModel()
361
397
>>> trainer = Trainer(callbacks=[backbone_finetuning])
362
398
363
399
"""
0 commit comments