diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 863a3a4a7e939..60b9593cddd49 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed preventing recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
+- Fixed bug that prevented `BackboneFinetuning` from being used together with `LearningRateFinder` ([#21224](https://github.com/Lightning-AI/pytorch-lightning/pull/21224))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05
diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py
index cec83fee0f4d7..f4bd8009ad13a 100644
--- a/src/lightning/pytorch/callbacks/finetuning.py
+++ b/src/lightning/pytorch/callbacks/finetuning.py
@@ -108,12 +108,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
         if self._restarting:
-            named_parameters = dict(pl_module.named_parameters())
-            for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self._apply_mapping_to_param_groups(
-                    self._internal_optimizer_metadata[opt_idx], named_parameters
-                )
-                optimizer.param_groups = param_groups
+            if self._internal_optimizer_metadata:
+                named_parameters = dict(pl_module.named_parameters())
+                for opt_idx, optimizer in enumerate(trainer.optimizers):
+                    if opt_idx in self._internal_optimizer_metadata:
+                        param_groups = self._apply_mapping_to_param_groups(
+                            self._internal_optimizer_metadata[opt_idx], named_parameters
+                        )
+                        optimizer.param_groups = param_groups
             self._restarting = False
 
     @staticmethod
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 81352ebe256ef..8c6b3f6f539d3 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -25,6 +25,7 @@
 
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -800,3 +801,47 @@ def configure_optimizers(self):
     assert len(lr_find_checkpoints) == 0, (
         f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
     )
+
+
+def test_lr_finder_with_backbone_finetuning_callback(tmp_path):
+    """Test that lr_find works correctly with BackboneFinetuning callback."""
+
+    class ModelWithBackbone(BoringModel):
+        def __init__(self):
+            super().__init__()
+            # Create a simple backbone-head architecture
+            self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
+            self.head = torch.nn.Linear(8, 2)
+            self.learning_rate = 1e-3
+
+        def forward(self, x):
+            backbone_features = self.backbone(x)
+            return self.head(backbone_features)
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+
+    model = ModelWithBackbone()
+    backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=1)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[backbone_finetuning],
+    )
+
+    tuner = Tuner(trainer)
+    lr_finder = tuner.lr_find(model, num_training=5)
+
+    assert lr_finder is not None
+    assert hasattr(lr_finder, "results")
+    assert len(lr_finder.results) > 0
+    trainer.fit(model)
+
+    # Check that backbone was unfrozen at the correct epoch
+    for param in model.backbone.parameters():
+        assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1"
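
For illustration only, and not part of the patch: a minimal sketch, using the public Tuner and BackboneFinetuning API, of the user-facing flow this fix enables, namely running Tuner.lr_find while a BackboneFinetuning callback is attached and then fitting as usual. The BackboneHeadModel, its layer sizes, and the hyperparameters below are placeholders, not code from the PR.

import torch

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner.tuning import Tuner


class BackboneHeadModel(BoringModel):
    """Toy model exposing the `backbone` attribute that BackboneFinetuning expects."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU())
        self.head = torch.nn.Linear(16, 2)
        self.learning_rate = 1e-3

    def forward(self, x):
        return self.head(self.backbone(x))

    def configure_optimizers(self):
        # The callback freezes the backbone before optimizers are built, so only
        # currently trainable parameters are handed to the optimizer here.
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        return torch.optim.Adam(trainable, lr=self.learning_rate)


model = BackboneHeadModel()
trainer = Trainer(
    max_epochs=2,
    callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=1)],
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
)

# Before this fix, restoring the lr_find checkpoint could fail in
# BaseFinetuning.on_fit_start because no optimizer metadata had been saved yet;
# the added guards simply skip that restore step when the metadata is absent.
lr_finder = Tuner(trainer).lr_find(model, num_training=5)
if lr_finder is not None and lr_finder.suggestion() is not None:
    model.learning_rate = lr_finder.suggestion()
trainer.fit(model)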