26 commits
5fd675c  call configure_module before freeze_before_training (Nov 18, 2024)
91775f7  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 18, 2024)
9da9e7d  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 18, 2024)
faef707  Merge branch 'master' into chualan/fix-19658 (lantiga, Nov 19, 2024)
90ff8f0  remove bad fix (Nov 21, 2024)
a205c4a  second fix and test case (Nov 22, 2024)
ef35dca  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 22, 2024)
0e570a8  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 22, 2024)
56d05a3  remove print statement (Nov 22, 2024)
1c040d7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 22, 2024)
bfa0fd4  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 25, 2024)
8ba644a  change assertion order for setup() and configure_model() in test_hook… (Nov 25, 2024)
1d8ef66  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 25, 2024)
11c8be4  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 26, 2024)
9e53990  Merge branch 'master' into chualan/fix-19658 (lantiga, Nov 26, 2024)
4567c49  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 9, 2024)
c6a77a9  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 10, 2024)
5be022e  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 11, 2024)
c20c173  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 11, 2024)
344822b  Merge branch 'master' into chualan/fix-19658 (Borda, Apr 16, 2025)
c7f02ed  Merge branch 'master' into chualan/fix-19658 (Borda, Apr 16, 2025)
3661d79  Merge branch 'master' into chualan/fix-19658 (Borda, Aug 8, 2025)
173ae26  Merge branch 'master' into chualan/fix-19658 (Borda, Sep 10, 2025)
ad3375e  update (deependujha, Oct 6, 2025)
cfe7a81  update (deependujha, Oct 6, 2025)
221602f  introduce `on_before_optimizer_setup` hook (deependujha, Oct 6, 2025)
8 changes: 8 additions & 0 deletions src/lightning/pytorch/callbacks/callback.py
@@ -277,6 +277,14 @@ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called after ``loss.backward()`` and before optimizers are stepped."""

    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.

        Useful when you need to make changes to the model before the optimizers are set up (e.g. freezing layers).

        """

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
    ) -> None:
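For orientation, here is a minimal sketch of a user-defined callback built on the new hook; the `backbone` attribute on the LightningModule is an assumption made purely for illustration:

```python
import lightning.pytorch as pl


class FreezeBackboneCallback(pl.Callback):
    """Illustrative sketch: freeze an assumed ``backbone`` submodule before optimizers are created."""

    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # `configure_model` has already run, so the submodules exist,
        # but `configure_optimizers` has not been called yet.
        for param in pl_module.backbone.parameters():
            param.requires_grad = False
```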
25 changes: 13 additions & 12 deletions src/lightning/pytorch/callbacks/finetuning.py
@@ -104,6 +104,19 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
            # compatibility to load from old checkpoints before PR #11887
            self._internal_optimizer_metadata = state_dict  # type: ignore[assignment]

    @override
    def on_before_optimizer_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # freeze the required modules before optimizers are created & after `configure_model` is run
        self.freeze_before_training(pl_module)

        from lightning.pytorch.strategies import DeepSpeedStrategy

        if isinstance(trainer.strategy, DeepSpeedStrategy):
            raise NotImplementedError(
                "The Finetuning callback does not support running with the DeepSpeed strategy."
                " Choose a different strategy or disable the callback."
            )

    @override
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # restore the param_groups created during the previous training.
@@ -273,18 +286,6 @@ def unfreeze_and_add_param_group(
        if params:
            optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

    @override
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        self.freeze_before_training(pl_module)

        from lightning.pytorch.strategies import DeepSpeedStrategy

        if isinstance(trainer.strategy, DeepSpeedStrategy):
            raise NotImplementedError(
                "The Finetuning callback does not support running with the DeepSpeed strategy."
                " Choose a different strategy or disable the callback."
            )

    @staticmethod
    def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]:
        output = []
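Moving the freezing logic from `setup` to `on_before_optimizer_setup` means `BaseFinetuning` subclasses now see the modules created in `configure_model`. A rough sketch of such a subclass (the `backbone` attribute and the milestone epoch are assumptions for illustration):

```python
from lightning.pytorch.callbacks import BaseFinetuning


class MilestoneFinetuning(BaseFinetuning):
    """Illustrative sketch of a BaseFinetuning subclass; ``pl_module.backbone`` is assumed to exist."""

    def __init__(self, unfreeze_at_epoch: int = 5):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # called from `on_before_optimizer_setup`, i.e. after `configure_model` has built the modules
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer):
        if current_epoch == self.unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer, train_bn=True)
```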
1 change: 1 addition & 0 deletions src/lightning/pytorch/callbacks/lambda_function.py
@@ -69,6 +69,7 @@ def __init__(
        on_load_checkpoint: Optional[Callable] = None,
        on_before_backward: Optional[Callable] = None,
        on_after_backward: Optional[Callable] = None,
        on_before_optimizer_setup: Optional[Callable] = None,
        on_before_optimizer_step: Optional[Callable] = None,
        on_before_zero_grad: Optional[Callable] = None,
        on_predict_start: Optional[Callable] = None,
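Since `LambdaCallback` forwards each keyword argument to the hook of the same name, the new hook can also be attached ad hoc; a small sketch:

```python
from lightning.pytorch.callbacks import LambdaCallback

# Illustrative: any callable accepting (trainer, pl_module) can be passed for the new hook.
hook_probe = LambdaCallback(
    on_before_optimizer_setup=lambda trainer, pl_module: print("model configured, optimizers not yet created")
)
```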
23 changes: 23 additions & 0 deletions src/lightning/pytorch/core/hooks.py
@@ -295,6 +295,29 @@ def on_after_backward(self) -> None:
"""

def on_before_optimizer_setup(self) -> None:
"""Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
It’s particularly useful for callbacks such as
:class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
prior to optimizer setup.
This hook runs once in fit stage, after the model
has been fully instantiated by ``configure_model``.
Example::
class MyFinetuneCallback(Callback):
def on_before_optimizer_setup(self, trainer, pl_module):
# freeze the backbone before optimizers are created
for param in pl_module.backbone.parameters():
param.requires_grad = False
"""
pass

def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
"""Called before ``optimizer.step()``.
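The hook can also be overridden directly on the LightningModule. A minimal sketch (module names and sizes are illustrative) of freezing part of a lazily built model before its parameters reach the optimizer:

```python
import torch
from torch import nn
from lightning.pytorch import LightningModule


class LazyFinetuneModel(LightningModule):
    def configure_model(self):
        # built lazily, e.g. to support sharded or meta-device initialization; guarded to stay idempotent
        if not hasattr(self, "backbone"):
            self.backbone = nn.Linear(32, 32)
            self.head = nn.Linear(32, 2)

    def on_before_optimizer_setup(self):
        # the submodules exist here, but no optimizer has been created yet
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        # only parameters that are still trainable end up in the optimizer
        return torch.optim.SGD((p for p in self.parameters() if p.requires_grad), lr=0.1)
```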
5 changes: 5 additions & 0 deletions src/lightning/pytorch/trainer/trainer.py
@@ -989,6 +989,11 @@ def _run(
        log.debug(f"{self.__class__.__name__}: configuring model")
        call._call_configure_model(self)

        # run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
        if self.state.fn == TrainerFn.FITTING:
            call._call_callback_hooks(self, "on_before_optimizer_setup")
            call._call_lightning_module_hook(self, "on_before_optimizer_setup")

        # check if we should delay restoring checkpoint till later
        if not self.strategy.restore_checkpoint_after_setup:
            log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
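A small sketch of how the resulting ordering could be observed from a callback during `fit` (the order implied by the change above is `setup`, then `on_before_optimizer_setup`, then `on_fit_start`):

```python
from lightning.pytorch import Callback


class HookOrderRecorder(Callback):
    """Illustrative: record the early fit-time hooks to check their relative order."""

    def __init__(self):
        self.calls = []

    def setup(self, trainer, pl_module, stage):
        self.calls.append("setup")

    def on_before_optimizer_setup(self, trainer, pl_module):
        self.calls.append("on_before_optimizer_setup")

    def on_fit_start(self, trainer, pl_module):
        self.calls.append("on_fit_start")
        # expected so far: ["setup", "on_before_optimizer_setup", "on_fit_start"]
```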
48 changes: 48 additions & 0 deletions tests/tests_pytorch/callbacks/test_finetuning_callback.py
@@ -431,3 +431,51 @@ def test_unsupported_strategies(tmp_path):
    trainer = Trainer(accelerator="cpu", strategy="deepspeed", callbacks=[callback])
    with pytest.raises(NotImplementedError, match="does not support running with the DeepSpeed strategy"):
        callback.setup(trainer, model, stage=None)


def test_finetuning_with_configure_model(tmp_path):
    """Test that BaseFinetuning works correctly with configure_model by ensuring freeze_before_training is called
    after configure_model but before training starts."""

    class TrackingFinetuningCallback(BaseFinetuning):
        def __init__(self):
            super().__init__()

        def freeze_before_training(self, pl_module):
            assert hasattr(pl_module, "backbone"), "backbone should be configured before freezing"
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, epoch, optimizer):
            pass

    class TestModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.configure_model_called_count = 0

        def configure_model(self):
            self.backbone = nn.Linear(32, 32)
            self.classifier = nn.Linear(32, 2)
            self.configure_model_called_count += 1

        def forward(self, x):
            x = self.backbone(x)
            return self.classifier(x)

        def training_step(self, batch, batch_idx):
            return self.forward(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    model = TestModel()
    callback = TrackingFinetuningCallback()
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[callback],
        max_epochs=1,
        limit_train_batches=1,
    )

    trainer.fit(model, torch.randn(10, 32))
    assert model.configure_model_called_count == 1