1 parent 064caf7 commit 70023a1
tests/tests_pytorch/plugins/precision/test_amp.py
@@ -55,13 +55,12 @@ def test_optimizer_amp_scaling_support_in_step_method():
     precision.clip_gradients(optimizer, clip_val=1.0)
 
 
-@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
-def test_amp_with_no_grad(precision: str):
+def test_amp_with_no_grad():
     """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
     tracking."""
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
-    amp = MixedPrecision(precision=precision, device="cpu")
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")
 
     with amp.autocast_context_manager():
         with torch.no_grad():
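For reference, a minimal sketch of the scenario the docstring describes, written against plain `torch.autocast` rather than Lightning's `MixedPrecision` plugin; the variable names and assertions below are illustrative and not the repository's exact test body:

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# A persistent autocast context with a nested no_grad block must not
# disable gradient tracking for ops performed after no_grad exits.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.no_grad():
        _ = layer(x)   # no graph is built inside no_grad
    out = layer(x)     # gradient tracking resumes here

assert out.requires_grad
out.sum().backward()
assert layer.weight.grad is not None
```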