2 files changed: +6 −10 lines

  src/lightning/pytorch/plugins/precision
  tests/tests_pytorch/plugins/precision

src/lightning/pytorch/plugins/precision

@@ -113,9 +113,7 @@ def clip_gradients(
 
     def autocast_context_manager(self) -> torch.autocast:
         return torch.autocast(
-            self.device,
-            dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half),
-            cache_enabled=False
+            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
         )
 
     @override
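
For reference, the reflowed call keeps the same behaviour: the autocast dtype is chosen from the plugin's precision string and the cast cache is disabled. Below is a minimal standalone sketch of that pattern, assuming a hypothetical autocast_for helper and the "bf16-mixed" setting (neither is part of the diff):

import torch

def autocast_for(precision: str, device: str = "cpu") -> torch.autocast:
    # Illustrative helper mirroring the plugin's logic: bfloat16 for
    # "bf16-mixed", float16 otherwise, with autocast's cast cache turned off.
    dtype = torch.bfloat16 if precision == "bf16-mixed" else torch.half
    return torch.autocast(device, dtype=dtype, cache_enabled=False)

layer = torch.nn.Linear(2, 1)
x = torch.randn(1, 2)
with autocast_for("bf16-mixed"):
    y = layer(x)  # the matmul runs in bfloat16 under CPU autocast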
tests/tests_pytorch/plugins/precision

 from unittest.mock import Mock
 
 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer
 
 from lightning.pytorch.plugins import MixedPrecision
-from lightning.pytorch.utilities import GradClipAlgorithmType
-
-from torch import nn
-import torch
-
 from lightning.pytorch.plugins.precision import MixedPrecision
+from lightning.pytorch.utilities import GradClipAlgorithmType
 
 
 def test_clip_gradients():
@@ -62,7 +60,7 @@ def test_optimizer_amp_scaling_support_in_step_method():
 def test_amp_with_no_grad(precision: str):
     layer = nn.Linear(2, 1)
     x = torch.randn(1, 2)
-    amp = MixedPrecision(precision=precision, device='cpu')
+    amp = MixedPrecision(precision=precision, device="cpu")
 
     with amp.autocast_context_manager():
         with torch.no_grad():
@@ -72,4 +70,4 @@ def test_amp_with_no_grad(precision: str):
 
     loss.backward()
 
-    assert loss.grad_fn is not None
+    assert loss.grad_fn is not None
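
For context on what the updated test exercises: the statements between the no_grad block and the backward call are not shown in this hunk, so the following is only a sketch of the pattern under assumed values ("bf16-mixed", a .sum() loss), not the actual test body. The point being checked is that a forward pass under torch.no_grad inside the autocast context does not prevent a later forward pass from recording a graph and backpropagating.

import torch
from torch import nn

from lightning.pytorch.plugins.precision import MixedPrecision

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)
amp = MixedPrecision(precision="bf16-mixed", device="cpu")

with amp.autocast_context_manager():
    with torch.no_grad():
        layer(x)           # inference-style forward: no autograd graph is recorded
    loss = layer(x).sum()  # assumed second forward that should track gradients

loss.backward()

assert loss.grad_fn is not None       # a graph was built despite the earlier no_grad forward
assert layer.weight.grad is not None  # and gradients flowed back to the parameters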