diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py
index 3bfb609465a83..a1fdf0b8857fd 100644
--- a/src/lightning/pytorch/callbacks/callback.py
+++ b/src/lightning/pytorch/callbacks/callback.py
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
     def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).
+
+        """
+
     def on_before_optimizer_step(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py
index cec83fee0f4d7..fd7776ebe7e67 100644
--- a/src/lightning/pytorch/callbacks/finetuning.py
+++ b/src/lightning/pytorch/callbacks/finetuning.py
@@ -104,6 +104,19 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
             # compatibility to load from old checkpoints before PR #11887
             self._internal_optimizer_metadata = state_dict  # type: ignore[assignment]
 
+    @override
+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # freeze the required modules before optimizers are created & after `configure_model` is run
+        self.freeze_before_training(pl_module)
+
+        from lightning.pytorch.strategies import DeepSpeedStrategy
+
+        if isinstance(trainer.strategy, DeepSpeedStrategy):
+            raise NotImplementedError(
+                "The Finetuning callback does not support running with the DeepSpeed strategy."
+                " Choose a different strategy or disable the callback."
+            )
+
     @override
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
@@ -273,18 +286,6 @@ def unfreeze_and_add_param_group(
         if params:
             optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})
 
-    @override
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        self.freeze_before_training(pl_module)
-
-        from lightning.pytorch.strategies import DeepSpeedStrategy
-
-        if isinstance(trainer.strategy, DeepSpeedStrategy):
-            raise NotImplementedError(
-                "The Finetuning callback does not support running with the DeepSpeed strategy."
-                " Choose a different strategy or disable the callback."
-            )
-
     @staticmethod
     def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]:
         output = []
diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py
index f04b2d777deb3..f8348ccb72a53 100644
--- a/src/lightning/pytorch/callbacks/lambda_function.py
+++ b/src/lightning/pytorch/callbacks/lambda_function.py
@@ -69,6 +69,7 @@ def __init__(
         on_load_checkpoint: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
+        on_before_optimizer_setup: Optional[Callable] = None,
         on_before_optimizer_step: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
         on_predict_start: Optional[Callable] = None,
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 0b0ab14244e38..c5439e1a7f181 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:
 
         """
 
+    def on_before_optimizer_setup(self) -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        This hook provides a safe point to modify, freeze, or inspect model parameters before the optimizers are
+        created. It is particularly useful for callbacks such as
+        :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
+        prior to optimizer setup.
+
+        This hook runs once during the ``fit`` stage, after the model
+        has been fully instantiated by ``configure_model``.
+
+        Example::
+
+            class MyFinetuneCallback(Callback):
+                def on_before_optimizer_setup(self, trainer, pl_module):
+                    # freeze the backbone before optimizers are created
+                    for param in pl_module.backbone.parameters():
+                        param.requires_grad = False
+
+        """
+        pass
+
     def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.
 
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 5768c507e2e3f..c0d6f18166652 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -989,6 +989,11 @@ def _run(
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
 
+        # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
+        if self.state.fn == TrainerFn.FITTING:
+            call._call_callback_hooks(self, "on_before_optimizer_setup")
+            call._call_lightning_module_hook(self, "on_before_optimizer_setup")
+
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
index b2fecedd342ea..cfd07cef99474 100644
--- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py
+++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
@@ -431,3 +431,51 @@ def test_unsupported_strategies(tmp_path):
     trainer = Trainer(accelerator="cpu", strategy="deepspeed", callbacks=[callback])
     with pytest.raises(NotImplementedError, match="does not support running with the DeepSpeed strategy"):
         callback.setup(trainer, model, stage=None)
+
+
+def test_finetuning_with_configure_model(tmp_path):
+    """Test that BaseFinetuning works correctly with configure_model by ensuring freeze_before_training is called
+    after configure_model but before training starts."""
+
+    class TrackingFinetuningCallback(BaseFinetuning):
+        def __init__(self):
+            super().__init__()
+
+        def freeze_before_training(self, pl_module):
+            assert hasattr(pl_module, "backbone"), "backbone should be configured before freezing"
+            self.freeze(pl_module.backbone)
+
+        def finetune_function(self, pl_module, epoch, optimizer):
+            pass
+
+    class TestModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.configure_model_called_count = 0
+
+        def configure_model(self):
+            self.backbone = nn.Linear(32, 32)
+            self.classifier = nn.Linear(32, 2)
+            self.configure_model_called_count += 1
+
+        def forward(self, x):
+            x = self.backbone(x)
+            return self.classifier(x)
+
+        def training_step(self, batch, batch_idx):
+            return self.forward(batch).sum()
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=0.1)
+
+    model = TestModel()
+    callback = TrackingFinetuningCallback()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[callback],
+        max_epochs=1,
+        limit_train_batches=1,
+    )
+
+    trainer.fit(model, torch.randn(10, 32))
+    assert model.configure_model_called_count == 1
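
For reference, a minimal usage sketch of the new hook from user code. The callback and module names (``FreezeBackbone``, ``DemoModel``) and the tensor-as-dataloader shortcut are hypothetical illustration choices; the hook name and its ordering (after ``configure_model``, before ``configure_optimizers``, fit stage only) come from the changes above::

    # Sketch with hypothetical names: freeze a `backbone` submodule via the new
    # `on_before_optimizer_setup` hook, which runs after `configure_model` and
    # before `configure_optimizers` during `fit`.
    import torch
    from torch import nn

    from lightning.pytorch import Callback, LightningModule, Trainer


    class FreezeBackbone(Callback):
        def on_before_optimizer_setup(self, trainer, pl_module):
            # runs after `configure_model`, so `backbone` already exists here
            for param in pl_module.backbone.parameters():
                param.requires_grad = False


    class DemoModel(LightningModule):
        def configure_model(self):
            # model is built here, so the callback above can rely on `backbone`
            self.backbone = nn.Linear(32, 32)
            self.head = nn.Linear(32, 2)

        def forward(self, x):
            return self.head(self.backbone(x))

        def training_step(self, batch, batch_idx):
            return self(batch).sum()

        def configure_optimizers(self):
            # only parameters left trainable by the callback reach the optimizer
            return torch.optim.SGD((p for p in self.parameters() if p.requires_grad), lr=0.1)


    trainer = Trainer(max_epochs=1, limit_train_batches=1, callbacks=[FreezeBackbone()])
    trainer.fit(DemoModel(), torch.randn(8, 32))

Because the hook fires before ``configure_optimizers``, the optimizer never sees the frozen parameters, which is the same guarantee ``BaseFinetuning`` now relies on instead of ``setup``.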