Skip to content

Commit bb63794

Browse files
committed
add toggled_optimizer to LightningModule
1 parent 6da480d commit bb63794

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,30 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
11411141
# save memory
11421142
self._param_requires_grad_state = {}
11431143

1144+
@contextmanager
1145+
def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
1146+
"""Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
1147+
prevent dangling gradients in multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
1148+
:meth:`untoggle_optimizer` into context manager.
1149+
1150+
Args:
1151+
optimizer: The optimizer to untoggle.
1152+
1153+
Example::
1154+
1155+
def training_step(...):
1156+
opt = self.optimizers()
1157+
with self.toggled_optimizer(opt):
1158+
loss = ...
1159+
opt.zero_grad()
1160+
self.manual_backward(loss)
1161+
opt.step()
1162+
"""
1163+
try:
1164+
yield self.toggle_optimizer(optimizer)
1165+
finally:
1166+
self.untoggle_optimizer(optimizer)
1167+
11441168
def clip_gradients(
11451169
self,
11461170
optimizer: Optimizer,

tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,21 @@ def test_1_optimizer_toggle_model():
118118
model.untoggle_optimizer(optimizer)
119119
assert not model._param_requires_grad_state
120120

121+
def test_1_optimizer_toggle_model_context_manager():
    """Test toggle_model runs when only one optimizer is used."""
    model = BoringModel()
    model.trainer = Mock()
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)
    model.trainer.optimizers = [sgd]

    # No per-parameter requires_grad state should be recorded before entering.
    assert not model._param_requires_grad_state
    # toggle optimizer was failing with a single optimizer
    with model.toggled_optimizer(sgd):
        # Inside the context the toggled state must be populated.
        assert model._param_requires_grad_state
    # Exiting the context restores and clears the saved state.
    assert not model._param_requires_grad_state
121136

122137
def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmp_path):
123138
class TestModel(BoringModel):

0 commit comments

Comments
 (0)