1 parent 600280b · commit 14ae4d8
tests/tests_pytorch/plugins/precision/test_amp.py
@@ -57,6 +57,8 @@ def test_optimizer_amp_scaling_support_in_step_method():
 
 
 @pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
 def test_amp_with_no_grad(precision: str):
+    """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
+    tracking."""
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
     amp = MixedPrecision(precision=precision, device="cpu")
@@ -66,7 +68,5 @@ def test_amp_with_no_grad(precision: str):
             _ = layer(x)
 
         loss = layer(x).mean()
-
         loss.backward()
         assert loss.grad_fn is not None
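For context, here is a minimal, self-contained sketch of the behaviour this test pins down, written against plain `torch.autocast` rather than Lightning's `MixedPrecision` plugin (the plugin's context-manager usage is not visible in this diff, so that part is an assumption): leaving a `torch.no_grad()` block while the surrounding autocast context stays active must not disable gradient tracking for later forward passes.

```python
# Minimal sketch (plain PyTorch, not the Lightning MixedPrecision plugin API) of
# the behaviour asserted by test_amp_with_no_grad: exiting a no_grad block while
# an autocast context remains active must not break gradient tracking afterwards.
import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# bf16 is used here because CPU autocast supports it; the actual test instead
# parametrizes over "16-mixed" and "bf16-mixed" via the MixedPrecision plugin.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.no_grad():
        _ = layer(x)  # no autograd graph is recorded for this forward pass

    loss = layer(x).mean()  # gradient tracking is active again here
    loss.backward()
    assert loss.grad_fn is not None  # the loss is still attached to a graph
```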