8 changes: 7 additions & 1 deletion docs/source-pytorch/common/hooks.rst
@@ -88,13 +88,19 @@ with the source of each hook indicated:
│ ├── [LightningModule]
│ ├── [LightningModule.configure_shared_model()]
│ ├── [LightningModule.configure_model()]
│ │
│ ├── on_before_optimizer_setup()
│ │ ├── [Callbacks]
│ │ └── [LightningModule]
│ │
│ ├── Strategy.restore_checkpoint_before_setup
│ │ ├── [LightningModule.on_load_checkpoint()]
│ │ ├── [LightningModule.load_state_dict()]
│ │ ├── [LightningDataModule.load_state_dict()]
│ │ ├── [Callbacks.on_load_checkpoint()]
│ │ └── [Callbacks.load_state_dict()]
│ └── [Strategy]
│ │
│ └── [Strategy] (configures optimizers, lr schedulers, precision, etc.)
├── on_fit_start()
│ ├── [Callbacks]
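The tree above places the new hook between ``configure_model`` and the strategy/optimizer setup. A minimal sketch of why that ordering matters (hypothetical ``FreezeBackbone``/``LitModel`` names, assuming ``configure_optimizers`` filters on ``requires_grad``): parameters frozen in the hook never reach the optimizer.

    import torch
    import lightning.pytorch as pl
    from torch import nn


    class FreezeBackbone(pl.Callback):
        def on_before_optimizer_setup(self, trainer, pl_module):
            # Runs after configure_model() and before configure_optimizers(),
            # so parameters frozen here are excluded from the optimizer below.
            for p in pl_module.backbone.parameters():
                p.requires_grad = False


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(32, 16)
            self.head = nn.Linear(16, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.head(self.backbone(x)), y)

        def configure_optimizers(self):
            # Only parameters that still require grad end up in the optimizer.
            return torch.optim.Adam([p for p in self.parameters() if p.requires_grad])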
8 changes: 8 additions & 0 deletions src/lightning/pytorch/callbacks/callback.py
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called after ``loss.backward()`` and before optimizers are stepped."""

def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.

Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).

"""

def on_before_optimizer_step(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
) -> None:
1 change: 1 addition & 0 deletions src/lightning/pytorch/callbacks/lambda_function.py
@@ -69,6 +69,7 @@ def __init__(
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_setup: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
on_predict_start: Optional[Callable] = None,
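Since ``LambdaCallback`` now accepts the new keyword, a one-off hook can be attached without subclassing ``Callback``. A small sketch (the ``freeze_backbone`` helper and the ``backbone`` attribute are illustrative, not part of this PR):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import LambdaCallback


    def freeze_backbone(trainer, pl_module):
        # Hypothetical helper: assumes the module exposes a `backbone` attribute.
        for p in pl_module.backbone.parameters():
            p.requires_grad = False


    trainer = Trainer(callbacks=[LambdaCallback(on_before_optimizer_setup=freeze_backbone)])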
23 changes: 23 additions & 0 deletions src/lightning/pytorch/core/hooks.py
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:

"""

def on_before_optimizer_setup(self) -> None:
"""Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.

This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
It is particularly useful for callbacks such as
:class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
prior to optimizer setup.

The hook runs exactly once, during the fit stage, after the model has been fully instantiated by
``configure_model`` but before optimizers are created by ``configure_optimizers``.

Example::

    class MyFinetuneCallback(Callback):
        def on_before_optimizer_setup(self, trainer, pl_module):
            # freeze the backbone before optimizers are created
            for param in pl_module.backbone.parameters():
                param.requires_grad = False

"""

def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
"""Called before ``optimizer.step()``.

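Because ``ModelHooks`` is mixed into ``LightningModule``, the hook can also be overridden on the module itself instead of (or in addition to) a callback. A minimal sketch, assuming a hypothetical ``backbone`` attribute built in ``configure_model``:

    import lightning.pytorch as pl
    from torch import nn


    class LitTransfer(pl.LightningModule):
        def configure_model(self):
            # Deferred construction; guaranteed to have run before the hook fires.
            self.backbone = nn.Linear(32, 2)

        def on_before_optimizer_setup(self):
            # Module-level variant: note it takes no trainer/pl_module arguments.
            for p in self.backbone.parameters():
                p.requires_grad = False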
@@ -35,6 +35,7 @@ class _LogOptions(TypedDict):
"on_after_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_optimizer_setup": None,
"on_before_optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
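The ``None`` entry registers the hook with the logging validator but marks ``self.log`` as unsupported inside it. A hedged workaround sketch (hypothetical ``CountTrainable`` callback): compute the value in the hook and log it later from a hook where logging is allowed, such as ``on_train_start``:

    import lightning.pytorch as pl


    class CountTrainable(pl.Callback):
        def on_before_optimizer_setup(self, trainer, pl_module):
            # Logging would raise here, so just stash the value for later.
            self._n_trainable = sum(p.numel() for p in pl_module.parameters() if p.requires_grad)

        def on_train_start(self, trainer, pl_module):
            # Logging is permitted in this hook, so emit the stored value now.
            pl_module.log("n_trainable_params", float(self._n_trainable))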
5 changes: 5 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
@@ -989,6 +989,11 @@ def _run(
log.debug(f"{self.__class__.__name__}: configuring model")
call._call_configure_model(self)

# run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
if self.state.fn == TrainerFn.FITTING:
    call._call_callback_hooks(self, "on_before_optimizer_setup")
    call._call_lightning_module_hook(self, "on_before_optimizer_setup")

# check if we should delay restoring checkpoint till later
if not self.strategy.restore_checkpoint_after_setup:
    log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
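Note the ``TrainerFn.FITTING`` guard: the new hooks run only for ``trainer.fit()``; ``validate``, ``test``, and ``predict`` skip them. A small sketch to illustrate (hypothetical ``Recorder`` callback, assuming the guard above):

    import lightning.pytorch as pl
    from lightning.pytorch.demos.boring_classes import BoringModel


    class Recorder(pl.Callback):
        def __init__(self):
            self.calls = 0

        def on_before_optimizer_setup(self, trainer, pl_module):
            self.calls += 1


    recorder = Recorder()
    trainer = pl.Trainer(callbacks=[recorder], fast_dev_run=True, logger=False)
    trainer.validate(BoringModel())
    assert recorder.calls == 0  # not the fit stage, so the hook is skipped
    trainer.fit(BoringModel())
    assert recorder.calls == 1  # runs exactly once during fit setup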
@@ -56,3 +56,55 @@ def on_test_batch_end(self, outputs, *_):
assert any(isinstance(c, CB) for c in trainer.callbacks)

trainer.fit(model)


def test_on_before_optimizer_setup_is_called_in_correct_order(tmp_path):
    """Ensure `on_before_optimizer_setup` runs after `configure_model` but before `configure_optimizers`."""

    order = []

    class TestCallback(Callback):
        def setup(self, trainer, pl_module, stage=None):
            order.append("setup")
            assert pl_module.layer is None
            assert len(trainer.optimizers) == 0

        def on_before_optimizer_setup(self, trainer, pl_module):
            order.append("on_before_optimizer_setup")
            # configure_model should already have been called
            assert pl_module.layer is not None
            # but optimizers are not yet created
            assert len(trainer.optimizers) == 0

        def on_fit_start(self, trainer, pl_module):
            order.append("on_fit_start")
            # optimizers should now exist
            assert len(trainer.optimizers) == 1
            assert pl_module.layer is not None

    class DemoModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = None

        def configure_model(self):
            from torch import nn

            self.layer = nn.Linear(32, 2)

    model = DemoModel()

    trainer = Trainer(
        callbacks=TestCallback(),
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        enable_model_summary=False,
        log_every_n_steps=1,
    )

    trainer.fit(model)

    # Verify call order
    assert order == ["setup", "on_before_optimizer_setup", "on_fit_start"]
6 changes: 6 additions & 0 deletions tests/tests_pytorch/models/test_hooks.py
@@ -477,6 +477,8 @@ def training_step(self, batch, batch_idx):
# DeepSpeed needs the batch size to figure out throughput logging
*([{"name": "train_dataloader"}] if using_deepspeed else []),
{"name": "configure_model"},
{"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)},
{"name": "on_before_optimizer_setup"},
{"name": "configure_optimizers"},
{"name": "Callback.on_fit_start", "args": (trainer, model)},
{"name": "on_fit_start"},
@@ -574,6 +576,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
{"name": "setup", "kwargs": {"stage": "fit"}},
{"name": "configure_model"},
{"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)},
{"name": "on_before_optimizer_setup"},
{"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
{"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
{"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -654,6 +658,8 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
{"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
{"name": "setup", "kwargs": {"stage": "fit"}},
{"name": "configure_model"},
{"name": "Callback.on_before_optimizer_setup", "args": (trainer, model)},
{"name": "on_before_optimizer_setup"},
{"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
{"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
{"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -43,6 +43,7 @@ def test_fx_validator():
callbacks_func = {
"on_before_backward",
"on_after_backward",
"on_before_optimizer_setup",
"on_before_optimizer_step",
"on_before_zero_grad",
"on_fit_end",
@@ -83,6 +84,7 @@ def test_fx_validator():
}

not_supported = {
"on_before_optimizer_setup",
"on_fit_end",
"on_fit_start",
"on_exception",
@@ -198,6 +200,7 @@ def test_fx_validator_integration(tmp_path):
"setup": "You can't",
"configure_model": "You can't",
"configure_optimizers": "You can't",
"on_before_optimizer_setup": "You can't",
"on_fit_start": "You can't",
"train_dataloader": "You can't",
"val_dataloader": "You can't",