1 change: 1 addition & 0 deletions docs/source-pytorch/conf.py
@@ -487,6 +487,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:meth", "setup"),
("py:meth", "test_step"),
("py:meth", "toggle_optimizer"),
("py:meth", "toggled_optimizer"),
("py:class", "torch.ScriptModule"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/manual_optimization.rst
@@ -17,7 +17,7 @@ To manually optimize, do the following:
* ``optimizer.zero_grad()`` to clear the gradients from the previous training step
* ``self.manual_backward(loss)`` instead of ``loss.backward()``
* ``optimizer.step()`` to update your model parameters
* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()`` if needed
* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()``, or ``self.toggled_optimizer()`` if needed

Here is a minimal example of manual optimization.
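The example itself is collapsed in this diff view; as a hedged sketch of the shape it takes with the new context manager (the model and the stand-in loss are illustrative, not part of the diff):

    import torch
    import lightning.pytorch as pl

    class ManualOptimModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.automatic_optimization = False  # opt out of the automatic loop

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            # requires_grad of params owned by other optimizers is restored on exit
            with self.toggled_optimizer(opt):
                loss = self.layer(batch).sum()  # stand-in loss for illustration
                opt.zero_grad()
                self.manual_backward(loss)
                opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)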

4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))


- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimizer` and `untoggle_optimizer` ([#20771](https://github.com/Lightning-AI/pytorch-lightning/pull/20771))


- For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780))


25 changes: 25 additions & 0 deletions src/lightning/pytorch/core/module.py
@@ -1141,6 +1141,31 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
        # save memory
        self._param_requires_grad_state = {}

    @contextmanager
    def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
        to prevent dangling gradients in a multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
        :meth:`untoggle_optimizer` into a context manager.

        Args:
            optimizer: The optimizer to toggle.

        Example::

            def training_step(...):
                opt = self.optimizers()
                with self.toggled_optimizer(opt):
                    loss = ...
                    opt.zero_grad()
                    self.manual_backward(loss)
                    opt.step()

        """
        try:
            yield self.toggle_optimizer(optimizer)
        finally:
            self.untoggle_optimizer(optimizer)

    def clip_gradients(
        self,
        optimizer: Optimizer,
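Because the yield sits inside a try/finally, ``untoggle_optimizer`` runs even when the block body raises, so no parameters are left with ``requires_grad`` disabled after an exception. A sketch of the two-optimizer case the method targets (GAN-style; ``generator_loss`` and ``discriminator_loss`` are illustrative helpers, not Lightning API):

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()

        # generator update: the discriminator's exclusive params get requires_grad=False
        with self.toggled_optimizer(opt_g):
            g_loss = self.generator_loss(batch)  # illustrative helper
            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()

        # discriminator update: the generator's exclusive params are frozen here
        with self.toggled_optimizer(opt_d):
            d_loss = self.discriminator_loss(batch)  # illustrative helper
            opt_d.zero_grad()
            self.manual_backward(d_loss)
            opt_d.step()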
16 changes: 16 additions & 0 deletions tests/tests_pytorch/core/test_lightning_module.py
@@ -119,6 +119,22 @@ def test_1_optimizer_toggle_model():
    assert not model._param_requires_grad_state


def test_1_optimizer_toggle_model_context_manager():
    """Test that the toggled_optimizer context manager toggles and restores state with a single optimizer."""
    model = BoringModel()
    trainer = Mock()
    model.trainer = trainer
    params = model.parameters()
    optimizer = torch.optim.SGD(params, lr=0.1)
    trainer.optimizers = [optimizer]

    assert not model._param_requires_grad_state
    # toggle_optimizer used to fail with a single optimizer
    with model.toggled_optimizer(optimizer):
        assert model._param_requires_grad_state
    assert not model._param_requires_grad_state
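A companion check not present in this diff would pin down the try/finally guarantee: the saved ``requires_grad`` state must be restored even when the body raises. A sketch, assuming ``pytest`` is already imported in this module:

    def test_toggled_optimizer_restores_state_on_error():
        """The context manager must untoggle even if the with-body raises."""
        model = BoringModel()
        trainer = Mock()
        model.trainer = trainer
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        trainer.optimizers = [optimizer]

        with pytest.raises(RuntimeError), model.toggled_optimizer(optimizer):
            raise RuntimeError("boom")
        # the finally block cleared the saved requires_grad state
        assert not model._param_requires_grad_state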


def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmp_path):
    class TestModel(BoringModel):
        def __init__(self):