Commit 28e0bdd

0x404 authored and awaelchli committed

Explicitly enable grad in closure (#18268)

Co-authored-by: awaelchli <[email protected]> (cherry picked from commit b88b8b3)

1 parent cfefd09 · commit 28e0bdd

3 files changed: +32 −0 lines changed


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


+- Ensure that the closure running inside the optimizer step has gradients enabled, even if the optimizer step has it disabled ([#18268](https://github.com/Lightning-AI/lightning/pull/18268))
+
+
 ## [2.0.5] - 2023-07-07

 ### Fixed
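
For context on this entry: when a third-party optimizer decorates its `step()` with `@torch.no_grad()` (Hugging Face Transformers' AdamW does this), any closure it calls runs with grad mode off, so the forward pass builds no autograd graph and `backward()` raises a RuntimeError. A minimal sketch of that failure mode, independent of Lightning (the names here are illustrative, not from the commit):

```python
import torch

model = torch.nn.Linear(4, 1)

def closure():
    # Forward + backward, roughly what an automatic-optimization closure does.
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    return loss

@torch.no_grad()  # mimics an optimizer whose step() disables grad
def step(closure):
    return closure()

try:
    step(closure)
except RuntimeError as err:
    # With grad mode off, the loss has no grad_fn, so backward() cannot run.
    print("closure failed inside no_grad step:", err)
```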

src/lightning/pytorch/loops/optimization/automatic.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def __init__(
         self._backward_fn = backward_fn
         self._zero_grad_fn = zero_grad_fn

+    @torch.enable_grad()
     def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
         step_output = self._step_fn()
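
The fix is the single `@torch.enable_grad()` decorator above: used as a decorator, it turns grad mode back on for the body of `closure`, even when the caller has disabled it. A minimal sketch of that mechanism using only stock PyTorch (again, names are illustrative):

```python
import torch

model = torch.nn.Linear(4, 1)

@torch.enable_grad()  # the same decorator this commit adds to the closure
def closure():
    loss = model(torch.randn(2, 4)).sum()
    assert torch.is_grad_enabled()  # grad mode is back on inside the closure
    loss.backward()
    return loss

@torch.no_grad()  # mimics a third-party optimizer's step()
def step(closure):
    return closure()

step(closure)
print(model.weight.grad is not None)  # True: backward() succeeded
```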

tests/tests_pytorch/loops/optimization/test_closure.py

Lines changed: 28 additions & 0 deletions
@@ -43,3 +43,31 @@ def step(self, closure=None):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
     with pytest.raises(MisconfigurationException, match="The closure hasn't been executed"):
         trainer.fit(model)
+
+
+def test_closure_with_no_grad_optimizer(tmpdir):
+    """Test that the closure is guaranteed to run with grad enabled.
+
+    There are certain third-party library optimizers
+    (such as Hugging Face Transformers' AdamW) that set `no_grad` during the `step` operation.
+
+    """
+
+    class NoGradAdamW(torch.optim.AdamW):
+        @torch.no_grad()
+        def step(self, closure):
+            if closure is not None:
+                closure()
+            return super().step()
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert torch.is_grad_enabled()
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            return NoGradAdamW(self.parameters(), lr=0.1)
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    model = TestModel()
+    trainer.fit(model)
