8 changes: 7 additions & 1 deletion docs/source-pytorch/common/hooks.rst
@@ -88,13 +88,19 @@ with the source of each hook indicated:
│ ├── [LightningModule]
│ ├── [LightningModule.configure_shared_model()]
│ ├── [LightningModule.configure_model()]
│ │
│ ├── on_before_optimizer_setup()
│ │ ├── [Callbacks]
│ │ └── [LightningModule]
│ │
│ ├── Strategy.restore_checkpoint_before_setup
│ │ ├── [LightningModule.on_load_checkpoint()]
│ │ ├── [LightningModule.load_state_dict()]
│ │ ├── [LightningDataModule.load_state_dict()]
│ │ ├── [Callbacks.on_load_checkpoint()]
│ │ └── [Callbacks.load_state_dict()]
│ └── [Strategy]
│ │
│ └── [Strategy] (configures optimizers, lr schedulers, precision, etc.)
├── on_fit_start()
│ ├── [Callbacks]
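
To make the ordering above concrete, here is a minimal sketch (not part of this PR) of a callback that observes the phases around optimizer setup; it assumes only the hook names shown in the tree::

    import lightning.pytorch as pl

    class HookOrderLogger(pl.Callback):
        def setup(self, trainer, pl_module, stage):
            # runs before configure_model(); lazily built layers may not exist yet
            print("setup")

        def on_before_optimizer_setup(self, trainer, pl_module):
            # runs after configure_model() but before any optimizer exists
            print("on_before_optimizer_setup")

        def on_fit_start(self, trainer, pl_module):
            # by now the Strategy has created optimizers, lr schedulers, precision, etc.
            print("on_fit_start")
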
8 changes: 8 additions & 0 deletions src/lightning/pytorch/callbacks/callback.py
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called after ``loss.backward()`` and before optimizers are stepped."""

def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).
"""

def on_before_optimizer_step(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
) -> None:
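
As a usage sketch for the new ``Callback`` hook (illustrative only, not part of this PR), a callback could report how many parameters will actually be trainable right before the optimizers are built::

    from lightning.pytorch.callbacks import Callback

    class ParamReport(Callback):
        def on_before_optimizer_setup(self, trainer, pl_module):
            # configure_model() has already run, so lazily created layers exist,
            # while trainer.optimizers is still empty at this point
            trainable = sum(p.numel() for p in pl_module.parameters() if p.requires_grad)
            total = sum(p.numel() for p in pl_module.parameters())
            print(f"trainable parameters before optimizer setup: {trainable}/{total}")
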
23 changes: 23 additions & 0 deletions src/lightning/pytorch/core/hooks.py
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:
"""

def on_before_optimizer_setup(self) -> None:
"""Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
It’s particularly useful for callbacks such as
:class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
prior to optimizer setup.
This hook runs once in fit stage, after the model
has been fully instantiated by ``configure_model``, but before optimizers are created by
``configure_optimizers``.
Example::
class MyFinetuneCallback(Callback):
def on_before_optimizer_setup(self, trainer, pl_module):
# freeze the backbone before optimizers are created
for param in pl_module.backbone.parameters():
param.requires_grad = False
"""

def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
"""Called before ``optimizer.step()``.
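
Since this is the ``LightningModule`` side of the hook, the override can also live directly on the module. The following is a hedged sketch (the ``encoder``/``head`` layout is an assumption, and ``training_step`` etc. are omitted)::

    import torch
    from torch import nn
    from lightning.pytorch import LightningModule

    class FinetuneModel(LightningModule):
        def configure_model(self):
            # build layers lazily; guard against repeated calls
            if not hasattr(self, "encoder"):
                self.encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
                self.head = nn.Linear(64, 10)

        def on_before_optimizer_setup(self):
            # runs after configure_model() and before configure_optimizers():
            # freeze the encoder so only the head is handed to the optimizer below
            for param in self.encoder.parameters():
                param.requires_grad = False

        def configure_optimizers(self):
            return torch.optim.Adam(p for p in self.parameters() if p.requires_grad)
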
5 changes: 5 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
@@ -989,6 +989,11 @@ def _run(
log.debug(f"{self.__class__.__name__}: configuring model")
call._call_configure_model(self)

# call the `on_before_optimizer_setup` hook after the model has been configured but before optimizers are set up
if self.state.fn == TrainerFn.FITTING:
call._call_callback_hooks(self, "on_before_optimizer_setup")
call._call_lightning_module_hook(self, "on_before_optimizer_setup")

# check if we should delay restoring checkpoint till later
if not self.strategy.restore_checkpoint_after_setup:
log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
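
Given the ``TrainerFn.FITTING`` guard above, the hook only fires for ``trainer.fit``. A quick sketch of what that implies (illustrative only, using Lightning's ``BoringModel`` demo class)::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import Callback
    from lightning.pytorch.demos.boring_classes import BoringModel

    class RecordCalls(Callback):
        def __init__(self):
            self.calls = 0

        def on_before_optimizer_setup(self, trainer, pl_module):
            self.calls += 1

    cb = RecordCalls()
    Trainer(callbacks=[cb], fast_dev_run=True, enable_model_summary=False).validate(BoringModel())
    assert cb.calls == 0  # not fitting -> hook never dispatched

    Trainer(callbacks=[cb], fast_dev_run=True, enable_model_summary=False).fit(BoringModel())
    assert cb.calls == 1  # fitting -> hook dispatched once, before configure_optimizers
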
@@ -56,3 +56,44 @@ def on_test_batch_end(self, outputs, *_):
assert any(isinstance(c, CB) for c in trainer.callbacks)

trainer.fit(model)


def test_callback_on_before_optimizer_setup(tmp_path):
"""Tests that on_before_optimizer_step is called as expected."""

class CB(Callback):
def setup(self, trainer, pl_module, stage=None):
assert len(trainer.optimizers) == 0
assert pl_module.layer is None # called before `LightningModule.configure_model`

def on_before_optimizer_setup(self, trainer, pl_module):
assert len(trainer.optimizers) == 0 # `LightningModule.configure_optimizers` hasn't been called yet
assert pl_module.layer is not None # called after `LightningModule.configure_model`

def on_fit_start(self, trainer, pl_module):
assert len(trainer.optimizers) == 1
assert pl_module.layer is not None # called after `LightningModule.configure_model`

class DemoModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = None # initialize layer in `configure_model`

def configure_model(self):
import torch.nn as nn

self.layer = nn.Linear(32, 2)

model = DemoModel()

trainer = Trainer(
callbacks=CB(),
default_root_dir=tmp_path,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
enable_model_summary=False,
)

trainer.fit(model)