Merged
Changes from 9 commits
4 changes: 3 additions & 1 deletion  src/lightning/pytorch/plugins/precision/amp.py

@@ -112,7 +112,9 @@ def clip_gradients(
         super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

     def autocast_context_manager(self) -> torch.autocast:
-        return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
+        return torch.autocast(
+            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
+        )

     @override
     @contextmanager
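For context, a minimal standalone sketch of the scenario this change targets, written against plain torch.autocast rather than the Lightning plugin. The CPU device, tensor shapes, and the comment about the autocast weight cache are assumptions on my part, not facts stated in the diff:

import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# cache_enabled=False turns off autocast's cache of down-cast weight copies.
# My reading of the fix: with the cache on, a weight copy created while
# torch.no_grad() is active can be reused by a later forward pass, which then
# carries no autograd history.
with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=False):
    with torch.no_grad():
        _ = layer(x)  # e.g. a validation-style pass inside the same context
    loss = layer(x).mean()  # a training-style pass must still track gradients

loss.backward()
assert loss.grad_fn is not None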
18 changes: 18 additions & 0 deletions  tests/tests_pytorch/plugins/precision/test_amp.py

@@ -14,6 +14,8 @@
 from unittest.mock import Mock

 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer

 from lightning.pytorch.plugins import MixedPrecision
@@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method():

     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+def test_amp_with_no_grad():
+    """Ensure that a `no_grad` context used inside a persistent AMP autocast context does not break gradient
+    tracking."""
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+    loss = layer(x).mean()
+    loss.backward()
+    assert loss.grad_fn is not None
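A hedged note on running the new test in isolation: a standard pytest node-id invocation along the lines of

pytest tests/tests_pytorch/plugins/precision/test_amp.py::test_amp_with_no_grad

should exercise just this case, though the repository's own test tooling or required markers may differ.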