Skip to content

Commit a19b3cb

Browse files
committed
fx validator fix
1 parent 050f583 commit a19b3cb

File tree

4 files changed

+22
-18
lines changed

4 files changed

+22
-18
lines changed

src/lightning/pytorch/core/hooks.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -295,28 +295,28 @@ def on_after_backward(self) -> None:
295295
296296
"""
297297

298-
def on_before_optimizer_setup(self) -> None:
299-
"""Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
300-
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
298+
# def on_before_optimizer_setup(self) -> None:
299+
# """Called after :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` but before
300+
# :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_optimizers`.
301301

302-
This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
303-
It’s particularly useful for callbacks such as
304-
:class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
305-
prior to optimizer setup.
302+
# This hook provides a safe point to modify, freeze, or inspect model parameters before optimizers are created.
303+
# It’s particularly useful for callbacks such as
304+
# :class:`~lightning.pytorch.callbacks.finetuning.BaseFinetuning`, where parameters must be frozen
305+
# prior to optimizer setup.
306306

307-
This hook runs once in fit stage, after the model
308-
has been fully instantiated by ``configure_model``, but before optimizers are created by
309-
``configure_optimizers``.
307+
# This hook runs once in fit stage, after the model
308+
# has been fully instantiated by ``configure_model``, but before optimizers are created by
309+
# ``configure_optimizers``.
310310

311-
Example::
311+
# Example::
312312

313-
class MyFinetuneCallback(Callback):
314-
def on_before_optimizer_setup(self, trainer, pl_module):
315-
# freeze the backbone before optimizers are created
316-
for param in pl_module.backbone.parameters():
317-
param.requires_grad = False
313+
# class MyFinetuneCallback(Callback):
314+
# def on_before_optimizer_setup(self, trainer, pl_module):
315+
# # freeze the backbone before optimizers are created
316+
# for param in pl_module.backbone.parameters():
317+
# param.requires_grad = False
318318

319-
"""
319+
# """
320320

321321
def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
322322
"""Called before ``optimizer.step()``.

src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class _LogOptions(TypedDict):
3535
"on_after_backward": _LogOptions(
3636
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
3737
),
38+
"on_before_optimizer_setup": None,
3839
"on_before_optimizer_step": _LogOptions(
3940
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
4041
),

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ def _run(
992992
# run hook `on_before_optimizer_setup` before optimizers are set up & after model is configured
993993
if self.state.fn == TrainerFn.FITTING:
994994
call._call_callback_hooks(self, "on_before_optimizer_setup")
995-
call._call_lightning_module_hook(self, "on_before_optimizer_setup")
995+
# call._call_lightning_module_hook(self, "on_before_optimizer_setup")
996996

997997
# check if we should delay restoring checkpoint till later
998998
if not self.strategy.restore_checkpoint_after_setup:

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_fx_validator():
4343
callbacks_func = {
4444
"on_before_backward",
4545
"on_after_backward",
46+
"on_before_optimizer_setup",
4647
"on_before_optimizer_step",
4748
"on_before_zero_grad",
4849
"on_fit_end",
@@ -83,6 +84,7 @@ def test_fx_validator():
8384
}
8485

8586
not_supported = {
87+
"on_before_optimizer_setup",
8688
"on_fit_end",
8789
"on_fit_start",
8890
"on_exception",
@@ -198,6 +200,7 @@ def test_fx_validator_integration(tmp_path):
198200
"setup": "You can't",
199201
"configure_model": "You can't",
200202
"configure_optimizers": "You can't",
203+
"on_before_optimizer_setup": "You can't",
201204
"on_fit_start": "You can't",
202205
"train_dataloader": "You can't",
203206
"val_dataloader": "You can't",

0 commit comments

Comments
 (0)