Commit 543a54f

Merge branch 'master' into feat/device_name
2 parents 7bf3645 + 2337948


44 files changed: +491 / -102 lines changed

.actions/assistant.py

Lines changed: 10 additions & 9 deletions
@@ -46,14 +46,14 @@


 class _RequirementWithComment(Requirement):
-    strict_string = "# strict"
+    strict_cmd = "strict"

     def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.comment = comment
         assert pip_argument is None or pip_argument # sanity check that it's not an empty str
         self.pip_argument = pip_argument
-        self.strict = self.strict_string in comment.lower()
+        self.strict = self.strict_cmd in comment.lower()

     def adjust(self, unfreeze: str) -> str:
         """Remove version restrictions unless they are strict.
@@ -62,25 +62,26 @@ def adjust(self, unfreeze: str) -> str:
         'arrow<=1.2.2,>=1.2.0'
         >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none")
         'arrow<=1.2.2,>=1.2.0 # strict'
-        >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all")
-        'arrow>=1.2.0'
+        >>> _RequirementWithComment('arrow<=1.2.2,>=1.2.0; python_version >= "3.10"', comment="# my name").adjust("all")
+        'arrow>=1.2.0; python_version >= "3.10"'
         >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all")
         'arrow<=1.2.2,>=1.2.0 # strict'
-        >>> _RequirementWithComment("arrow").adjust("all")
-        'arrow'
+        >>> _RequirementWithComment('arrow; python_version >= "3.10"').adjust("all")
+        'arrow; python_version >= "3.10"'
         >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major")
         'arrow<2.0,>=1.2.0'
         >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major")
         'arrow<=1.2.2,>=1.2.0 # strict'
-        >>> _RequirementWithComment("arrow>=1.2.0").adjust("major")
-        'arrow>=1.2.0'
+        >>> _RequirementWithComment('arrow>=1.2.0; python_version >= "3.10"').adjust("major")
+        'arrow>=1.2.0; python_version >= "3.10"'
         >>> _RequirementWithComment("arrow").adjust("major")
         'arrow'

         """
         out = str(self)
         if self.strict:
-            return f"{out} {self.strict_string}"
+            return f"{out} # {self.strict_cmd}"
+
         specs = [(spec.operator, spec.version) for spec in self.specifier]
         if unfreeze == "major":
             for operator, version in specs:
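
The updated doctests above exercise requirements that carry environment markers (e.g. `; python_version >= "3.10"`), checking that the marker survives when version pins are relaxed. As a rough standalone illustration of that behavior, here is a minimal sketch built directly on the `packaging` library rather than on the repository's `_RequirementWithComment` class; the `drop_upper_bounds` helper is hypothetical:

    from packaging.requirements import Requirement

    def drop_upper_bounds(line: str) -> str:
        """Hypothetical helper: drop upper bounds, keep any environment marker."""
        req = Requirement(line)
        # Keep only the lower-bound specifiers; upper bounds get "unfrozen" away.
        specs = ",".join(f"{s.operator}{s.version}" for s in req.specifier if s.operator in (">=", ">"))
        out = f"{req.name}{specs}"
        if req.marker is not None:
            # Re-attach the marker; dropping it would make the requirement
            # apply to every Python version and platform.
            out += f"; {req.marker}"
        return out

    print(drop_upper_bounds('arrow<=1.2.2,>=1.2.0; python_version >= "3.10"'))
    # arrow>=1.2.0; python_version >= "3.10"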

.github/workflows/_legacy-checkpoints.yml

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ jobs:
       - name: Install uv and set Python version
         uses: astral-sh/setup-uv@v7
         with:
-          python-version: "3.9"
+          python-version: "3.10"
           # TODO: Avoid activating environment like this
           # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment
           activate-environment: true

.github/workflows/ci-pkg-install.yml

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ jobs:
     matrix:
       os: ["ubuntu-22.04", "macOS-14", "windows-2022"]
       pkg-name: ["fabric", "pytorch", "lightning", "notset"]
-      python-version: ["3.9", "3.11"]
+      python-version: ["3.10", "3.11"]
   steps:
     - uses: actions/checkout@v5
     - uses: actions/setup-python@v6

.github/workflows/ci-tests-fabric.yml

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ jobs:
       - name: Install uv and set Python version
         uses: astral-sh/setup-uv@v7
         with:
-          python-version: ${{ matrix.config.python-version || '3.9' }}
+          python-version: ${{ matrix.config.python-version || '3.10' }}
           # TODO: Avoid activating environment like this
           # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment
           activate-environment: true

.github/workflows/ci-tests-pytorch.yml

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ jobs:
       - name: Install uv and set Python version
         uses: astral-sh/setup-uv@v7
         with:
-          python-version: ${{ matrix.config.python-version || '3.9' }}
+          python-version: ${{ matrix.config.python-version || '3.10' }}
           # TODO: Avoid activating environment like this
           # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment
           activate-environment: true

.github/workflows/release-nightly.yml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ jobs:
       - uses: actions/checkout@v5
       - uses: actions/setup-python@v6
         with:
-          python-version: 3.9
+          python-version: "3.10"

       - name: Convert actual version to nightly
         run: |

.github/workflows/release-pkg.yml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ defaults:
 env:
   FREEZE_REQUIREMENTS: 1
   TORCH_URL: "https://download.pytorch.org/whl/cpu/"
-  PYTHON_VER: "3.9"
+  PYTHON_VER: "3.10"

 jobs:
   build-packages:

.readthedocs.yml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ sphinx:
 build:
   os: ubuntu-20.04
   tools:
-    python: "3.9"
+    python: "3.10"
   apt_packages:
     - texlive-latex-extra
     - dvipng

docs/source-pytorch/advanced/transfer_learning.rst

Lines changed: 132 additions & 0 deletions
@@ -126,3 +126,135 @@ Here's a model that uses `Huggingface transformers <https://github.com/huggingfa
             h_cls = h[:, 0]
             logits = self.W(h_cls)
             return logits, attn
+
+----
+
+***********************************
+Automated Finetuning with Callbacks
+***********************************
+
+PyTorch Lightning provides the :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback to automate
+the finetuning process. This callback gradually unfreezes your model's backbone during training. This is particularly
+useful when working with large pretrained models, as it allows you to start training with a frozen backbone and
+then progressively unfreeze layers to fine-tune the model.
+
+The :class:`~lightning.pytorch.callbacks.BackboneFinetuning` callback expects your model to have a specific structure:
+
+.. testcode::
+
+    class MyModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+
+            # REQUIRED: Your model must have a 'backbone' attribute
+            # This should be the pretrained part you want to finetune
+            self.backbone = some_pretrained_model
+
+            # Your task-specific layers (head, classifier, etc.)
+            self.head = nn.Linear(backbone_features, num_classes)
+
+        def configure_optimizers(self):
+            # Only optimize the head initially - backbone will be added automatically
+            return torch.optim.Adam(self.head.parameters(), lr=1e-3)
+
+************************************
+Example: Computer Vision with ResNet
+************************************
+
+Here's a complete example showing how to use :class:`~lightning.pytorch.callbacks.BackboneFinetuning`
+for computer vision:
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+    import torchvision.models as models
+    from lightning.pytorch import LightningModule, Trainer
+    from lightning.pytorch.callbacks import BackboneFinetuning
+
+
+    class ResNetClassifier(LightningModule):
+        def __init__(self, num_classes=10, learning_rate=1e-3):
+            super().__init__()
+            self.save_hyperparameters()
+
+            # Create backbone from pretrained ResNet
+            resnet = models.resnet50(weights="DEFAULT")
+            # Remove the final classification layer
+            self.backbone = nn.Sequential(*list(resnet.children())[:-1])
+
+            # Add custom classification head
+            self.head = nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(resnet.fc.in_features, 512),
+                nn.ReLU(),
+                nn.Dropout(0.2),
+                nn.Linear(512, num_classes)
+            )
+
+        def forward(self, x):
+            # Extract features with backbone
+            features = self.backbone(x)
+            # Classify with head
+            return self.head(features)
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            y_hat = self(x)
+            loss = nn.functional.cross_entropy(y_hat, y)
+            self.log('train_loss', loss)
+            return loss
+
+        def configure_optimizers(self):
+            # Initially only train the head - backbone will be added by callback
+            return torch.optim.Adam(self.head.parameters(), lr=self.hparams.learning_rate)
+
+
+    # Setup the finetuning callback
+    backbone_finetuning = BackboneFinetuning(
+        unfreeze_backbone_at_epoch=10,  # Start unfreezing backbone at epoch 10
+        lambda_func=lambda epoch: 1.5,  # Gradually increase backbone learning rate
+        backbone_initial_ratio_lr=0.1,  # Backbone starts at 10% of head learning rate
+        should_align=True,  # Align rates when backbone rate reaches head rate
+        verbose=True  # Print learning rates during training
+    )
+
+    model = ResNetClassifier()
+    trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)
+
+****************************
+Custom Finetuning Strategies
+****************************
+
+For more control, you can create custom finetuning strategies by subclassing
+:class:`~lightning.pytorch.callbacks.BaseFinetuning`:
+
+.. testcode::
+
+    from lightning.pytorch.callbacks.finetuning import BaseFinetuning
+
+
+    class CustomFinetuning(BaseFinetuning):
+        def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2):
+            super().__init__()
+            self.unfreeze_at_epoch = unfreeze_at_epoch
+            self.layers_per_epoch = layers_per_epoch
+
+        def freeze_before_training(self, pl_module):
+            # Freeze the entire backbone initially
+            self.freeze(pl_module.backbone)
+
+        def finetune_function(self, pl_module, epoch, optimizer):
+            # Gradually unfreeze layers
+            if epoch >= self.unfreeze_at_epoch:
+                layers_to_unfreeze = min(
+                    self.layers_per_epoch,
+                    len(list(pl_module.backbone.children()))
+                )
+
+                # Unfreeze from the top layers down
+                backbone_children = list(pl_module.backbone.children())
+                for layer in backbone_children[-layers_to_unfreeze:]:
+                    self.unfreeze_and_add_param_group(
+                        layer, optimizer, lr=1e-4
+                    )
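
The new `CustomFinetuning` section stops at the class definition; wiring it into a run mirrors the `BackboneFinetuning` example earlier on the page. A minimal sketch, reusing the page's `ResNetClassifier` and assuming the module (or a separate datamodule) supplies the dataloaders:

    from lightning.pytorch import Trainer

    # CustomFinetuning only requires a `backbone` attribute on the module,
    # since freeze_before_training() freezes pl_module.backbone.
    model = ResNetClassifier()
    callback = CustomFinetuning(unfreeze_at_epoch=5, layers_per_epoch=2)
    trainer = Trainer(callbacks=[callback], max_epochs=20)
    trainer.fit(model)  # assumes the module defines its own train_dataloader()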

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ filterwarnings = [
     # "error::DeprecationWarning",
     "error::FutureWarning",
     "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated
+    "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning",
 ]
 xfail_strict = true
 junit_duration_report = "call"
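
The new filter uses pytest's `action:message-regex:category` syntax, which mirrors Python's own warning filters: the `"error::FutureWarning"` entry above turns every FutureWarning into a test failure, and the added `ignore:` entry carves out only the `torch.load` message. A small sketch of the same precedence in plain `warnings` calls (the warning strings here are illustrative):

    import warnings

    # Later-registered filters win, matching how pytest applies later list entries.
    warnings.filterwarnings("error", category=FutureWarning)
    warnings.filterwarnings(
        "ignore",
        message=r"You are using `torch.load` with `weights_only=False`.*",
        category=FutureWarning,
    )

    # Silently ignored by the carve-out:
    warnings.warn("You are using `torch.load` with `weights_only=False` ...", FutureWarning)

    # Any other FutureWarning is still escalated to an error:
    try:
        warnings.warn("some other future change", FutureWarning)
    except FutureWarning:
        print("other FutureWarnings still raise")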
