14 | 14 | from unittest.mock import Mock
15 | 15 |
16 | 16 | import pytest
   | 17 | +from torch.nn import Module
17 | 18 | from torch.optim import Optimizer
18 | 19 |
19 | 20 | from lightning.pytorch.plugins import MixedPrecision

22 | 23 |
23 | 24 | def test_clip_gradients():
24 | 25 |     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
   | 26 | +    module = Mock(spec=Module)
25 | 27 |     optimizer = Mock(spec=Optimizer)
26 | 28 |     precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
27 | 29 |     precision.clip_grad_by_value = Mock()
28 | 30 |     precision.clip_grad_by_norm = Mock()
29 |    | -    precision.clip_gradients(optimizer)
   | 31 | +    precision.clip_gradients(module, optimizer)
30 | 32 |     precision.clip_grad_by_value.assert_not_called()
31 | 33 |     precision.clip_grad_by_norm.assert_not_called()
32 | 34 |
33 |    | -    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
   | 35 | +    precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
34 | 36 |     precision.clip_grad_by_value.assert_called_once()
35 | 37 |     precision.clip_grad_by_norm.assert_not_called()
36 | 38 |
37 | 39 |     precision.clip_grad_by_value.reset_mock()
38 | 40 |     precision.clip_grad_by_norm.reset_mock()
39 | 41 |
40 |    | -    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
   | 42 | +    precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
41 | 43 |     precision.clip_grad_by_value.assert_not_called()
42 | 44 |     precision.clip_grad_by_norm.assert_called_once()
43 | 45 |
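The updated test reflects a signature change: `clip_gradients` on the precision plugin now receives the module ahead of the optimizer. Below is a minimal, self-contained sketch of the new call pattern; the `GradClipAlgorithmType` import path is an assumption here, since the test's real import sits in the lines collapsed out of the hunk above.

from unittest.mock import Mock

from torch.nn import Module
from torch.optim import Optimizer

from lightning.pytorch.plugins import MixedPrecision
# Assumed import path for the enum; the test's actual import is outside the shown hunk.
from lightning.pytorch.utilities.enums import GradClipAlgorithmType

precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
# Stub out the concrete clipping helpers, exactly as the test does, so no real gradients are needed.
precision.clip_grad_by_value = Mock()
precision.clip_grad_by_norm = Mock()

module = Mock(spec=Module)
optimizer = Mock(spec=Optimizer)

# New signature: the module is passed before the optimizer.
precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
precision.clip_grad_by_norm.assert_called_once()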