
Commit b2ef2e3

baskrahmer, pre-commit-ci[bot], and Borda committed
Fix: no_grad with AMP bug (#20921)
* Disable cache for torch.autocast in amp
* Add a test
* Only test for bf16-mixed
* Implement test to reproduce the issue

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 216f9ec)
1 parent 257db2b · commit b2ef2e3
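For context, the behavior this commit works around can be reproduced with plain PyTorch, without Lightning. The sketch below is illustrative only and not part of the commit (the `layer`, `x`, and `out` names are made up for the example); it assumes a CPU bf16 autocast region and shows why a forward pass under `torch.no_grad()` inside a long-lived autocast context can break gradient tracking for later forwards, and why passing `cache_enabled=False`, as the patched `autocast_context_manager` does, avoids it.

import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# With the autocast weight cache enabled (the default), the forward pass under
# `torch.no_grad()` caches bf16 copies of the weights that carry no autograd
# history; the later forward in the same autocast region reuses those copies,
# so its output can come out detached from the graph.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=True):
    with torch.no_grad():
        _ = layer(x)
    out = layer(x)
    print(out.grad_fn)  # None on affected setups, so backward() would fail

# Disabling the cache forces autocast to re-cast the weights on every forward,
# which keeps the autograd history intact.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    with torch.no_grad():
        _ = layer(x)
    out = layer(x)
    print(out.grad_fn)  # a valid grad_fn; out.mean().backward() succeeds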

File tree

  • src/lightning/pytorch/plugins/precision/amp.py
  • tests/tests_pytorch/plugins/precision/test_amp.py

2 files changed: +20 -1 lines


src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 2 additions & 1 deletion
@@ -112,7 +112,8 @@ def clip_gradients(
         super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

     def autocast_context_manager(self) -> torch.autocast:
-        return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
+        dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
+        return torch.autocast(self.device, dtype=dtype, cache_enabled=False)

     @override
     @contextmanager

tests/tests_pytorch/plugins/precision/test_amp.py

Lines changed: 18 additions & 0 deletions
@@ -14,6 +14,8 @@
 from unittest.mock import Mock

 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer

 from lightning.pytorch.plugins import MixedPrecision
@@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method():

     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+def test_amp_with_no_grad():
+    """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
+    tracking."""
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+        loss = layer(x).mean()
+        loss.backward()
+        assert loss.grad_fn is not None
