diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py
index 90400b1df491d..62cd21fc127f4 100644
--- a/docs/source-pytorch/conf.py
+++ b/docs/source-pytorch/conf.py
@@ -487,6 +487,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
     ("py:meth", "setup"),
     ("py:meth", "test_step"),
     ("py:meth", "toggle_optimizer"),
+    ("py:meth", "toggled_optimizer"),
     ("py:class", "torch.ScriptModule"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload"),
     ("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),
diff --git a/docs/source-pytorch/model/manual_optimization.rst b/docs/source-pytorch/model/manual_optimization.rst
index 150f04793eae6..4c7400c0457ca 100644
--- a/docs/source-pytorch/model/manual_optimization.rst
+++ b/docs/source-pytorch/model/manual_optimization.rst
@@ -17,7 +17,7 @@ To manually optimize, do the following:
   * ``optimizer.zero_grad()`` to clear the gradients from the previous training step
   * ``self.manual_backward(loss)`` instead of ``loss.backward()``
   * ``optimizer.step()`` to update your model parameters
-  * ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()`` if needed
+  * ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()``, or the ``self.toggled_optimizer()`` context manager, if needed
 
 Here is a minimal example of manual optimization.
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index eaca4443ae434..5a03b7f953c5b 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
 
+
+- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimizer` and `untoggle_optimizer` ([#20771](https://github.com/Lightning-AI/pytorch-lightning/pull/20771))
+
+
 - For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780))
 
 
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 8b2387fcea481..c484a95c6c632 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1141,6 +1141,32 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
         # save memory
         self._param_requires_grad_state = {}
 
+    @contextmanager
+    def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
+        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
+        to prevent dangling gradients in a multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
+        :meth:`untoggle_optimizer` into a context manager.
+
+        Args:
+            optimizer: The optimizer to toggle.
+
+        Example::
+
+            def training_step(...):
+                opt = self.optimizers()
+                with self.toggled_optimizer(opt):
+                    loss = ...
+                    opt.zero_grad()
+                    self.manual_backward(loss)
+                    opt.step()
+
+        """
+        self.toggle_optimizer(optimizer)
+        try:
+            yield
+        finally:
+            self.untoggle_optimizer(optimizer)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 2036014762ebf..c33488a4f2626 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -119,6 +119,22 @@ def test_1_optimizer_toggle_model():
     assert not model._param_requires_grad_state
 
 
+def test_optimizer_toggle_model_context_manager():
+    """Test that the ``toggled_optimizer`` context manager toggles and untoggles when only one optimizer is used."""
+    model = BoringModel()
+    trainer = Mock()
+    model.trainer = trainer
+    params = model.parameters()
+    optimizer = torch.optim.SGD(params, lr=0.1)
+    trainer.optimizers = [optimizer]
+
+    assert not model._param_requires_grad_state
+    # the context manager must also work with a single optimizer
+    with model.toggled_optimizer(optimizer):
+        assert model._param_requires_grad_state
+    assert not model._param_requires_grad_state
+
+
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmp_path):
     class TestModel(BoringModel):
         def __init__(self):
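
Below is a minimal, self-contained sketch of how the new `toggled_optimizer` context manager can be used in a manually optimized module with two optimizers. The `ToyGAN` class, its placeholder losses, the random dataset, and the hyperparameters are illustrative assumptions and not part of this change; only `toggled_optimizer`, `manual_backward`, `self.optimizers()`, and the manual optimization flow come from the Lightning API.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class ToyGAN(pl.LightningModule):
    """Two-optimizer module whose updates are scoped with ``toggled_optimizer``."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization is required
        self.generator = nn.Linear(8, 16)
        self.discriminator = nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch  # TensorDataset yields 1-tuples
        opt_g, opt_d = self.optimizers()

        # Generator update: discriminator parameters are frozen inside this block,
        # so only generator parameters accumulate gradients.
        with self.toggled_optimizer(opt_g):
            g_loss = -self.discriminator(self.generator(x)).mean()  # placeholder loss
            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()

        # Discriminator update: generator parameters are frozen inside this block.
        with self.toggled_optimizer(opt_d):
            d_loss = self.discriminator(self.generator(x).detach()).mean()  # placeholder loss
            opt_d.zero_grad()
            self.manual_backward(d_loss)
            opt_d.step()

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.generator.parameters(), lr=0.01),
            torch.optim.SGD(self.discriminator.parameters(), lr=0.01),
        )


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 8))
    loader = DataLoader(dataset, batch_size=16)
    trainer = pl.Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
    trainer.fit(ToyGAN(), loader)
```

Inside each `with self.toggled_optimizer(opt):` block, parameters that do not belong to `opt` temporarily have `requires_grad` set to `False`, and the previous state is restored on exit, mirroring what explicit `toggle_optimizer`/`untoggle_optimizer` calls do, without the risk of forgetting the untoggle on an early return or exception.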