Skip to content

Commit 3df1993

Browse files
committed
improve docstring
1 parent 7845b5d commit 3df1993

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,46 @@ class BackboneFinetuning(BaseFinetuning):
354354
355355
Example::
356356
357-
>>> from lightning.pytorch import Trainer
357+
>>> import torch
358+
>>> import torch.nn as nn
359+
>>> from lightning.pytorch import LightningModule, Trainer
358360
>>> from lightning.pytorch.callbacks import BackboneFinetuning
361+
>>> import torchvision.models as models
362+
>>>
363+
>>> class TransferLearningModel(LightningModule):
364+
... def __init__(self, num_classes=10):
365+
... super().__init__()
366+
... # REQUIRED: Your model must have a 'backbone' attribute
367+
... self.backbone = models.resnet50(weights="DEFAULT")
368+
... # Remove the final classification layer from backbone
369+
... self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
370+
...
371+
... # Add your task-specific head
372+
... self.head = nn.Sequential(
373+
... nn.Flatten(),
374+
... nn.Linear(2048, 512),
375+
... nn.ReLU(),
376+
... nn.Linear(512, num_classes)
377+
... )
378+
...
379+
... def forward(self, x):
380+
... # Extract features with backbone
381+
... features = self.backbone(x)
382+
... # Classify with head
383+
... return self.head(features)
384+
...
385+
... def configure_optimizers(self):
386+
... # Initially only optimize the head - backbone will be added by callback
387+
... return torch.optim.Adam(self.head.parameters(), lr=1e-3)
388+
...
389+
>>> # Setup the callback
359390
>>> multiplicative = lambda epoch: 1.5
360-
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
391+
>>> backbone_finetuning = BackboneFinetuning(
392+
... unfreeze_backbone_at_epoch=10, # Start unfreezing at epoch 10
393+
... lambda_func=multiplicative, # Gradually increase backbone LR
394+
... backbone_initial_ratio_lr=0.1, # Start backbone at 10% of head LR
395+
... )
396+
>>> model = TransferLearningModel()
361397
>>> trainer = Trainer(callbacks=[backbone_finetuning])
362398
363399
"""

0 commit comments

Comments
 (0)