diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst
index 89c1c15d0413f..82ecc99c6810a 100644
--- a/docs/source-pytorch/common/hooks.rst
+++ b/docs/source-pytorch/common/hooks.rst
@@ -88,13 +88,19 @@ with the source of each hook indicated:
 │   ├── [LightningModule]
 │   ├── [LightningModule.configure_shared_model()]
 │   ├── [LightningModule.configure_model()]
+│   │
+│   ├── on_before_optimizer_setup()
+│   │   ├── [Callbacks]
+│   │   └── [LightningModule]
+│   │
 │   ├── Strategy.restore_checkpoint_before_setup
 │   │   ├── [LightningModule.on_load_checkpoint()]
 │   │   ├── [LightningModule.load_state_dict()]
 │   │   ├── [LightningDataModule.load_state_dict()]
 │   │   ├── [Callbacks.on_load_checkpoint()]
 │   │   └── [Callbacks.load_state_dict()]
-│   └── [Strategy]
+│   │
+│   └── [Strategy] (configures optimizers, lr schedulers, precision, etc.)
 │
 ├── on_fit_start()
 │   ├── [Callbacks]
diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py
index 3bfb609465a83..a1fdf0b8857fd 100644
--- a/src/lightning/pytorch/callbacks/callback.py
+++ b/src/lightning/pytorch/callbacks/callback.py
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
     def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).
+
+        """
+
     def on_before_optimizer_step(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py
index f04b2d777deb3..f8348ccb72a53 100644
--- a/src/lightning/pytorch/callbacks/lambda_function.py
+++ b/src/lightning/pytorch/callbacks/lambda_function.py
@@ -69,6 +69,7 @@ def __init__(
         on_load_checkpoint: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
+        on_before_optimizer_setup: Optional[Callable] = None,
         on_before_optimizer_step: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
         on_predict_start: Optional[Callable] = None,
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 0b0ab14244e38..a979232d7ac72 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:
 
         """
 
+    def on_before_optimizer_setup(self) -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
+        It's particularly useful for callbacks such as
+        :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
+        prior to optimizer setup.
+
+        This hook runs once during the fit stage, after the model
+        has been fully instantiated by ``configure_model``, but before optimizers are created by
+        ``configure_optimizers``.
+
+        Example::
+
+            class MyFinetuneCallback(Callback):
+                def on_before_optimizer_setup(self, trainer, pl_module):
+                    # freeze the backbone before optimizers are created
+                    for param in pl_module.backbone.parameters():
+                        param.requires_grad = False
+
+        """
+
     def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.
 
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
index c1ee0013bfa19..01264d750ccc0 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py
@@ -35,6 +35,7 @@ class _LogOptions(TypedDict):
        "on_after_backward": _LogOptions(
            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
        ),
+        "on_before_optimizer_setup": None,
        "on_before_optimizer_step": _LogOptions(
            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
        ),
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 5768c507e2e3f..c0d6f18166652 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -989,6 +989,11 @@ def _run(
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
 
+        # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
+        if self.state.fn == TrainerFn.FITTING:
+            call._call_callback_hooks(self, "on_before_optimizer_setup")
+            call._call_lightning_module_hook(self, "on_before_optimizer_setup")
+
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
diff --git a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py b/tests/tests_pytorch/callbacks/test_callback_hooks.py
similarity index 54%
rename from tests/tests_pytorch/callbacks/test_callback_hook_outputs.py
rename to tests/tests_pytorch/callbacks/test_callback_hooks.py
index 366a924a5867c..5e4bf1096a4fb 100644
--- a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py
+++ b/tests/tests_pytorch/callbacks/test_callback_hooks.py
@@ -56,3 +56,55 @@ def on_test_batch_end(self, outputs, *_):
     assert any(isinstance(c, CB) for c in trainer.callbacks)
 
     trainer.fit(model)
+
+
+def test_on_before_optimizer_setup_is_called_in_correct_order(tmp_path):
+    """Ensure `on_before_optimizer_setup` runs after `configure_model` but before `configure_optimizers`."""
+
+    order = []
+
+    class TestCallback(Callback):
+        def setup(self, trainer, pl_module, stage=None):
+            order.append("setup")
+            assert pl_module.layer is None
+            assert len(trainer.optimizers) == 0
+
+        def on_before_optimizer_setup(self, trainer, pl_module):
+            order.append("on_before_optimizer_setup")
+            # configure_model should already have been called
+            assert pl_module.layer is not None
+            # but optimizers are not yet created
+            assert len(trainer.optimizers) == 0
+
+        def on_fit_start(self, trainer, pl_module):
+            order.append("on_fit_start")
+            # optimizers should now exist
+            assert len(trainer.optimizers) == 1
+            assert pl_module.layer is not None
+
+    class DemoModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer = None
+
+        def configure_model(self):
+            from torch import nn
+
+            self.layer = nn.Linear(32, 2)
+
+    model = DemoModel()
+
+    trainer = Trainer(
+        callbacks=TestCallback(),
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        enable_model_summary=False,
+        log_every_n_steps=1,
+    )
+
+    trainer.fit(model)
+
+    # Verify call order
+    assert order == ["setup", "on_before_optimizer_setup", "on_fit_start"]
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index e943d0533cab5..77f8bbb642b2e 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -477,6 +477,8 @@ def training_step(self, batch, batch_idx):
         # DeepSpeed needs the batch size to figure out throughput logging
         *([{"name": "train_dataloader"}] if using_deepspeed else []),
         {"name": "configure_model"},
+        {"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)},
+        {"name": "on_before_optimizer_setup"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
@@ -574,6 +576,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
         {"name": "configure_model"},
+        {"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)},
+        {"name": "on_before_optimizer_setup"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -654,6 +658,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
         {"name": "configure_model"},
+        {"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)},
+        {"name": "on_before_optimizer_setup"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index d3d355edb003b..8166113974566 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -43,6 +43,7 @@ def test_fx_validator():
     callbacks_func = {
         "on_before_backward",
         "on_after_backward",
+        "on_before_optimizer_setup",
         "on_before_optimizer_step",
         "on_before_zero_grad",
         "on_fit_end",
@@ -83,6 +84,7 @@
     }
 
     not_supported = {
+        "on_before_optimizer_setup",
         "on_fit_end",
         "on_fit_start",
         "on_exception",
@@ -198,6 +200,7 @@ def test_fx_validator_integration(tmp_path):
         "setup": "You can't",
         "configure_model": "You can't",
         "configure_optimizers": "You can't",
+        "on_before_optimizer_setup": "You can't",
         "on_fit_start": "You can't",
         "train_dataloader": "You can't",
         "val_dataloader": "You can't",
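
For reference, a minimal end-to-end sketch of how the new hook could be used from user code once this change lands. This example is not part of the diff above; the ``VisionModel``, ``backbone``, and ``head`` names are illustrative only, and the freezing logic is just one possible use of the hook::

    import torch
    from torch import nn
    import lightning.pytorch as pl


    class VisionModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # built lazily in configure_model (useful e.g. with sharded strategies)
            self.backbone = None
            self.head = None

        def configure_model(self):
            if self.backbone is None:
                self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
                self.head = nn.Linear(64, 2)

        def on_before_optimizer_setup(self):
            # runs once during fit: after configure_model, before configure_optimizers
            for param in self.backbone.parameters():
                param.requires_grad = False

        def forward(self, x):
            return self.head(self.backbone(x))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self(x), y)

        def configure_optimizers(self):
            # only the head remains trainable because the backbone was frozen above
            trainable = [p for p in self.parameters() if p.requires_grad]
            return torch.optim.Adam(trainable, lr=1e-3)

The same effect can be achieved from a ``Callback`` via ``on_before_optimizer_setup(trainer, pl_module)``, as shown in the docstring example added to ``hooks.py`` above.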