src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed preventing recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))


- Fixed bug that prevented `BackboneFinetuning` from being used together with `LearningRateFinder` ([#21224](https://github.com/Lightning-AI/pytorch-lightning/pull/21224))


---

## [2.5.5] - 2025-09-05
src/lightning/pytorch/callbacks/finetuning.py (8 additions, 6 deletions)
@@ -108,12 +108,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
         if self._restarting:
-            named_parameters = dict(pl_module.named_parameters())
-            for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self._apply_mapping_to_param_groups(
-                    self._internal_optimizer_metadata[opt_idx], named_parameters
-                )
-                optimizer.param_groups = param_groups
+            if self._internal_optimizer_metadata:
+                named_parameters = dict(pl_module.named_parameters())
+                for opt_idx, optimizer in enumerate(trainer.optimizers):
+                    if opt_idx in self._internal_optimizer_metadata:
+                        param_groups = self._apply_mapping_to_param_groups(
+                            self._internal_optimizer_metadata[opt_idx], named_parameters
+                        )
+                        optimizer.param_groups = param_groups
             self._restarting = False

@staticmethod
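The guarded restore above only rebuilds `param_groups` for optimizers whose metadata the callback actually recorded; optimizers it never touched (such as those set up during a temporary `lr_find` run, which is the interaction the changelog entry refers to) are now left alone instead of triggering a lookup into missing metadata. Conceptually, the restore step maps saved parameter names back to the live parameter objects. The sketch below illustrates that idea only; it is not the library's `_apply_mapping_to_param_groups` implementation, and the helper name is made up.

```python
from typing import Any

import torch


def rebuild_param_groups(
    saved_groups: list[dict[str, Any]],
    named_parameters: dict[str, torch.nn.Parameter],
) -> list[dict[str, Any]]:
    """Illustrative only: turn param-group metadata that stores parameter names
    back into optimizer param groups holding the live ``nn.Parameter`` objects."""
    restored = []
    for group in saved_groups:
        # carry over hyperparameters such as lr or weight_decay unchanged
        new_group = {key: value for key, value in group.items() if key != "params"}
        # replace saved parameter names with the corresponding live parameters
        new_group["params"] = [named_parameters[name] for name in group["params"]]
        restored.append(new_group)
    return restored
```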
tests/tests_pytorch/tuner/test_lr_finder.py (45 additions, 0 deletions)
@@ -25,6 +25,7 @@

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -800,3 +801,47 @@ def configure_optimizers(self):
    assert len(lr_find_checkpoints) == 0, (
        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
    )


def test_lr_finder_with_backbone_finetuning_callback(tmp_path):
    """Test that lr_find works correctly with BackboneFinetuning callback."""

    class ModelWithBackbone(BoringModel):
        def __init__(self):
            super().__init__()
            # Create a simple backbone-head architecture
            self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
            self.head = torch.nn.Linear(8, 2)
            self.learning_rate = 1e-3

        def forward(self, x):
            backbone_features = self.backbone(x)
            return self.head(backbone_features)

        def configure_optimizers(self):
            return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)

    model = ModelWithBackbone()
    backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=1)

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=3,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
        logger=False,
        callbacks=[backbone_finetuning],
    )

    tuner = Tuner(trainer)
    lr_finder = tuner.lr_find(model, num_training=5)

    assert lr_finder is not None
    assert hasattr(lr_finder, "results")
    assert len(lr_finder.results) > 0
    trainer.fit(model)

    # Check that backbone was unfrozen at the correct epoch
    for param in model.backbone.parameters():
        assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1"
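Beyond the assertions in the test, a typical user-facing flow would apply the suggested learning rate before the final fit. The lines below are a minimal sketch that reuses the `tuner`, `model`, and `trainer` objects from the test above; note that `suggestion()` may return `None`, and that `lr_find` can also update the attribute itself via its `update_attr` argument.

```python
# Sketch only: continues from the objects created in the test above.
lr_finder = tuner.lr_find(model, num_training=5)
suggested_lr = lr_finder.suggestion()  # may be None if no suggestion could be made
if suggested_lr is not None:
    model.learning_rate = suggested_lr  # read by configure_optimizers() on the next fit
trainer.fit(model)
```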