tests/tests_pytorch/plugins/precision (1 file changed, +3 -3 lines changed)

from unittest.mock import Mock

import pytest
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Module
from torch.optim import Optimizer


def test_clip_gradients():
    """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
-    module = FSDP(Mock(spec=Module))
+    module = Mock(spec=Module)
    optimizer = Mock(spec=Optimizer)
    precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
    precision.clip_grad_by_value = Mock()
@@ -49,8 +48,9 @@ def test_optimizer_amp_scaling_support_in_step_method():
    """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
    gradient clipping (example: fused Adam)."""

+    module = Mock(spec=Module)
    optimizer = Mock(_step_supports_amp_scaling=True)
    precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())

    with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
-        precision.clip_gradients(optimizer, clip_val=1.0)
+        precision.clip_gradients(module, optimizer, clip_val=1.0)
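
For reference, below is a minimal sketch of how the second test could read once this diff is applied. It is not the repository's exact file: the `MixedPrecision` import path (`lightning.pytorch.plugins`) is an assumption, since the import block is not part of the diff above; everything else mirrors the changed lines.

# Post-diff sketch of test_optimizer_amp_scaling_support_in_step_method.
# Assumption: MixedPrecision is importable from lightning.pytorch.plugins
# (the import is not shown in the diff above).
from unittest.mock import Mock

import pytest
from torch.nn import Module

from lightning.pytorch.plugins import MixedPrecision  # assumed import path


def test_optimizer_amp_scaling_support_in_step_method():
    """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
    gradient clipping (example: fused Adam)."""
    # A plain mocked module replaces the FSDP-wrapped mock that the diff removes.
    module = Mock(spec=Module)
    # Simulate an optimizer (e.g. fused Adam) that performs unscaling inside its own step.
    optimizer = Mock(_step_supports_amp_scaling=True)
    precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())

    # The plugin should refuse gradient clipping for such optimizers;
    # `clip_gradients` now receives the module as its first argument.
    with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
        precision.clip_gradients(module, optimizer, clip_val=1.0)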