
Commit 221602f

introduce on_before_optimizer_setup hook
1 parent cfe7a81 commit 221602f

5 files changed: +41 -2 lines changed

src/lightning/pytorch/callbacks/callback.py

Lines changed: 8 additions & 0 deletions
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
     def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""

+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).
+
+        """
+
     def on_before_optimizer_step(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
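
A minimal usage sketch (not part of this commit; the callback name and the prefix argument are illustrative) of a custom callback implementing the new hook to freeze a subset of parameters before the optimizers are built:

    from lightning.pytorch.callbacks import Callback


    class FreezeByPrefix(Callback):
        def __init__(self, prefix: str) -> None:
            self.prefix = prefix

        def on_before_optimizer_setup(self, trainer, pl_module):
            # runs after ``configure_model`` and before ``configure_optimizers``,
            # so ``requires_grad`` is already final when the optimizers are created
            for name, param in pl_module.named_parameters():
                if name.startswith(self.prefix):
                    param.requires_grad = False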

src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 4 additions & 2 deletions
@@ -105,8 +105,8 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self._internal_optimizer_metadata = state_dict  # type: ignore[assignment]

     @override
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        # freeze the required modules before training
+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # freeze the required modules before optimizers are created & after `configure_model` is run
         self.freeze_before_training(pl_module)

         from lightning.pytorch.strategies import DeepSpeedStrategy
@@ -117,6 +117,8 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                 " Choose a different strategy or disable the callback."
             )

+    @override
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
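
A minimal sketch (not part of this commit; ``BackboneFreezeUnfreeze`` and ``pl_module.backbone`` are hypothetical) of how this plays out for ``BaseFinetuning`` subclasses: ``freeze_before_training`` now runs from ``on_before_optimizer_setup``, i.e. after ``configure_model`` but before the optimizers are created, while ``finetune_function`` keeps operating on the optimizer during training:

    from lightning.pytorch.callbacks import BaseFinetuning


    class BackboneFreezeUnfreeze(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch: int = 10) -> None:
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # now invoked via on_before_optimizer_setup, so the backbone is frozen
            # before configure_optimizers builds the parameter groups
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, current_epoch, optimizer):
            # later in training, unfreeze the backbone and add it to the optimizer
            if current_epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer)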

src/lightning/pytorch/callbacks/lambda_function.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def __init__(
         on_load_checkpoint: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
+        on_before_optimizer_setup: Optional[Callable] = None,
         on_before_optimizer_step: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
         on_predict_start: Optional[Callable] = None,
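
A minimal usage sketch (not part of this commit; ``pl_module.backbone`` is a hypothetical attribute): with the new parameter, a one-off hook can be attached without subclassing ``Callback``:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import LambdaCallback

    freeze_backbone = LambdaCallback(
        on_before_optimizer_setup=lambda trainer, pl_module: pl_module.backbone.requires_grad_(False)
    )

    trainer = Trainer(callbacks=[freeze_backbone])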

src/lightning/pytorch/core/hooks.py

Lines changed: 23 additions & 0 deletions
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:

         """

+    def on_before_optimizer_setup(self) -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
+        It is particularly useful for callbacks such as
+        :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
+        prior to optimizer setup.
+
+        This hook runs once during the fit stage, after the model
+        has been fully instantiated by ``configure_model``.
+
+        Example::
+
+            class MyFinetuneCallback(Callback):
+                def on_before_optimizer_setup(self, trainer, pl_module):
+                    # freeze the backbone before optimizers are created
+                    for param in pl_module.backbone.parameters():
+                        param.requires_grad = False
+
+        """
+        pass
+
     def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.

src/lightning/pytorch/trainer/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -989,6 +989,11 @@ def _run(
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)

+        # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
+        if self.state.fn == TrainerFn.FITTING:
+            call._call_callback_hooks(self, "on_before_optimizer_setup")
+            call._call_lightning_module_hook(self, "on_before_optimizer_setup")
+
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
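
A runnable sketch (not part of this commit; the toy module and data are illustrative) that makes the resulting ordering visible during ``fit``: ``configure_model`` runs first, then ``on_before_optimizer_setup`` (callbacks, then the LightningModule), and only afterwards are the optimizers created via ``configure_optimizers``:

    import torch
    from lightning.pytorch import LightningModule, Trainer


    class OrderDemo(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def configure_model(self):
            print("configure_model")

        def on_before_optimizer_setup(self):
            print("on_before_optimizer_setup")  # fit stage only

        def configure_optimizers(self):
            print("configure_optimizers")
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def train_dataloader(self):
            return torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=4)


    Trainer(max_epochs=1, logger=False, enable_checkpointing=False).fit(OrderDemo())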
