tests/tests_pytorch/plugins/precision (1 file changed: 22 additions, 0 deletions)

 from lightning.pytorch.plugins import MixedPrecision
 from lightning.pytorch.utilities import GradClipAlgorithmType

+from torch import nn
+import torch
+
+from lightning.pytorch.plugins.precision import MixedPrecision
+

 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
@@ -51,3 +56,20 @@ def test_optimizer_amp_scaling_support_in_step_method():

     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_amp_with_no_grad(precision: str):
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision=precision, device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+    loss = layer(x).mean()
+
+    loss.backward()
+
+    assert loss.grad_fn is not None
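
For context, the new `test_amp_with_no_grad` checks that running a `torch.no_grad()` forward pass inside the plugin's autocast context does not break autograd for a later forward: the loss computed afterwards must still carry a `grad_fn` so that `backward()` succeeds. Below is a minimal sketch of the same scenario using plain `torch.autocast` instead of Lightning's `MixedPrecision` wrapper; CPU and bfloat16 are assumptions chosen here for illustration, not part of the diff above.

import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

# An inference-only forward pass inside autocast, e.g. a validation step.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.no_grad():
        _ = layer(x)

# A later grad-enabled forward must still build an autograd graph,
# which is exactly what the test's final assertion verifies.
loss = layer(x).mean()
assert loss.grad_fn is not None
loss.backward()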