
Commit e7e512d

feat: introduce new callback on_before_optimizer_setup
1 parent 4dc1dbc

File tree

4 files changed: +77 −0 lines changed


src/lightning/pytorch/callbacks/callback.py

Lines changed: 8 additions & 0 deletions
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
     def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""
 
+    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).
+
+        """
+
     def on_before_optimizer_step(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
     ) -> None:
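
The "freezing layers" note is the motivating case: an optimizer captures its parameter references at construction time, so any `requires_grad` changes must land before `configure_optimizers` runs. A minimal sketch of the model-side counterpart, assuming illustrative `backbone`/`head` attributes that are not part of this diff:

import torch
import torch.nn as nn
import lightning.pytorch as pl


class TransferModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)  # frozen externally via `on_before_optimizer_setup`
        self.head = nn.Linear(16, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.head(self.backbone(x)), y)

    def configure_optimizers(self):
        # only parameters that still require grad are handed to the optimizer,
        # which is why freezing must happen before this hook runs
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=1e-3)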

src/lightning/pytorch/core/hooks.py

Lines changed: 23 additions & 0 deletions
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:
 
         """
 
+    def on_before_optimizer_setup(self) -> None:
+        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
+
+        This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
+        It is particularly useful for callbacks such as
+        :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
+        prior to optimizer setup.
+
+        This hook runs once during the fit stage, after the model
+        has been fully instantiated by ``configure_model``, but before optimizers are created by
+        ``configure_optimizers``.
+
+        Example::
+
+            class MyFinetuneCallback(Callback):
+                def on_before_optimizer_setup(self, trainer, pl_module):
+                    # freeze the backbone before optimizers are created
+                    for param in pl_module.backbone.parameters():
+                        param.requires_grad = False
+
+        """
+
     def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
         """Called before ``optimizer.step()``.
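
Because the hook is also defined on `ModelHooks`, the freezing logic from the docstring example can live on the LightningModule itself instead of in a callback; note the module-level variant takes no extra arguments. A minimal sketch, assuming a hypothetical `backbone` attribute:

import torch.nn as nn
from lightning.pytorch.demos.boring_classes import BoringModel


class FrozenBackboneModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)

    def on_before_optimizer_setup(self):
        # runs after `configure_model` and before `configure_optimizers`
        for param in self.backbone.parameters():
            param.requires_grad = False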

src/lightning/pytorch/trainer/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -989,6 +989,11 @@ def _run(
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
 
+        # run hook `on_before_optimizer_setup` after the model is configured and before optimizers are set up
+        if self.state.fn == TrainerFn.FITTING:
+            call._call_callback_hooks(self, "on_before_optimizer_setup")
+            call._call_lightning_module_hook(self, "on_before_optimizer_setup")
+
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
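
Given this placement in `Trainer._run`, the ordering observable during `fit` (and asserted by the test below) is: `setup` → `configure_model` → `on_before_optimizer_setup` (callback hooks first, then the module hook) → `configure_optimizers` → `on_fit_start`. The `TrainerFn.FITTING` guard means the hook is skipped entirely for validate/test/predict. A small sketch that records this ordering (a hypothetical recorder callback, not part of the diff):

from lightning.pytorch import Callback


class HookOrderRecorder(Callback):
    """Records the relative order of the hooks touched by this change."""

    def __init__(self):
        self.calls = []

    def setup(self, trainer, pl_module, stage):
        self.calls.append("setup")

    def on_before_optimizer_setup(self, trainer, pl_module):
        self.calls.append("on_before_optimizer_setup")

    def on_fit_start(self, trainer, pl_module):
        self.calls.append("on_fit_start")


# after `trainer.fit(model)` with this callback attached, `calls` should read
# ["setup", "on_before_optimizer_setup", "on_fit_start"]; under `trainer.validate(model)`
# the middle entry is absent because `state.fn` is not `FITTING`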

tests/tests_pytorch/callbacks/test_callback_hook_outputs.py renamed to tests/tests_pytorch/callbacks/test_callback_hooks.py

Lines changed: 41 additions & 0 deletions
@@ -56,3 +56,44 @@ def on_test_batch_end(self, outputs, *_):
     assert any(isinstance(c, CB) for c in trainer.callbacks)
 
     trainer.fit(model)
+
+
+def test_callback_on_before_optimizer_setup(tmp_path):
+    """Tests that `on_before_optimizer_setup` is called as expected."""
+
+    class CB(Callback):
+        def setup(self, trainer, pl_module, stage=None):
+            assert len(trainer.optimizers) == 0
+            assert pl_module.layer is None  # setup is called before `LightningModule.configure_model`
+
+        def on_before_optimizer_setup(self, trainer, pl_module):
+            assert len(trainer.optimizers) == 0
+            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+
+        def on_fit_start(self, trainer, pl_module):
+            assert len(trainer.optimizers) == 1
+            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+
+    class DemoModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer = None  # initialized in `configure_model`
+
+        def configure_model(self):
+            import torch.nn as nn
+
+            self.layer = nn.Linear(32, 2)
+
+    model = DemoModel()
+
+    trainer = Trainer(
+        callbacks=CB(),
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)
