@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
 import torch
 
 from lightning.fabric import seed_everything
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.plugins.precision import MixedPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf
 
 
@@ -54,3 +57,31 @@ def run(fused=False):
     # Both the regular and the fused version of Adam produce the same losses and model weights
     for p, q in zip(params, params_fused):
         torch.testing.assert_close(p, q)
+
+
+@RunIf(min_cuda_gpus=1)
+def test_skip_training_step_with_grad_scaler():
+    """Test that the grad scaler gets skipped when skipping a training step."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            if batch_idx % 2:
+                return None  # skipping the backward should skip the grad scaler too
+            return super().training_step(batch, batch_idx)
+
+    trainer = Trainer(
+        accelerator="cuda",
+        devices=1,
+        precision="16-mixed",
+        barebones=True,
+        max_steps=5,
+        gradient_clip_val=0.5,
+    )
+    assert isinstance(trainer.precision_plugin, MixedPrecisionPlugin)
+    assert trainer.precision_plugin.scaler is not None
+    trainer.precision_plugin.scaler = Mock(wraps=trainer.precision_plugin.scaler)
+    model = TestModel()
+    trainer.fit(model)
+    assert trainer.precision_plugin.scaler.unscale_.call_count == 3
+    assert trainer.precision_plugin.scaler.step.call_count == 3
+    assert trainer.precision_plugin.scaler.update.call_count == 3
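For context on what the new test asserts: with max_steps=5 and odd batch indices skipped, only batches 0, 2, and 4 reach the optimizer, so each scaler method runs exactly three times, and Mock(wraps=...) records those calls while delegating to the real scaler. Below is a minimal plain-PyTorch sketch of the same skip-aware AMP loop the test exercises; the model, optimizer, learning rate, and tensor shapes are placeholder assumptions, not taken from the PR, and a CUDA device is assumed.

# Standalone sketch (placeholder model/optimizer, assumes a CUDA device) of why
# the asserted call counts are 3: only batches 0, 2, and 4 reach the scaler.
from unittest.mock import Mock

import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# wraps=... keeps the real GradScaler behavior while recording every call
scaler = Mock(wraps=torch.cuda.amp.GradScaler())

for batch_idx in range(5):  # mirrors max_steps=5
    if batch_idx % 2:
        continue  # skipped step: no backward, so no scaler calls either
    batch = torch.randn(4, 32, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale before clipping, as gradient_clip_val requires
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

assert scaler.unscale_.call_count == scaler.step.call_count == scaler.update.call_count == 3

The unscale_-before-clipping order matters: gradients must be unscaled so the 0.5 clip threshold applies to true gradient magnitudes, which is why the test enables gradient_clip_val and counts unscale_ calls alongside step and update.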