|
15 | 15 | import json |
16 | 16 | import logging |
17 | 17 | import os |
| 18 | +from re import escape |
18 | 19 | from typing import Any, Dict |
19 | 20 | from unittest import mock |
20 | 21 |
|
|
26 | 27 | from torchmetrics import Accuracy |
27 | 28 |
|
28 | 29 | from lightning.pytorch import LightningDataModule, LightningModule, Trainer |
| 30 | +from lightning.pytorch.accelerators import CUDAAccelerator |
29 | 31 | from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint |
30 | 32 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset |
31 | 33 | from lightning.pytorch.loggers import CSVLogger |
32 | 34 | from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin |
33 | | -from lightning.pytorch.strategies import DeepSpeedStrategy |
34 | | -from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE |
| 35 | +from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy |
35 | 36 | from lightning.pytorch.utilities.exceptions import MisconfigurationException |
36 | 37 | from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 |
37 | 38 | from tests_pytorch.helpers.datamodules import ClassifDataModule |
@@ -1154,48 +1155,6 @@ def test_deepspeed_gradient_clip_by_value(tmpdir): |
1154 | 1155 | trainer.fit(model) |
1155 | 1156 |
|
1156 | 1157 |
|
1157 | | -@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) |
1158 | | -def test_specific_gpu_device_id(tmpdir): |
1159 | | - class TestCallback(Callback): |
1160 | | - def on_train_start(self, *_) -> None: |
1161 | | - assert model.device.index == 1 |
1162 | | - |
1163 | | - def on_train_batch_start( |
1164 | | - self, |
1165 | | - trainer: Trainer, |
1166 | | - pl_module: LightningModule, |
1167 | | - batch: Any, |
1168 | | - *_, |
1169 | | - ) -> None: |
1170 | | - assert batch.device.index == 1 |
1171 | | - |
1172 | | - def on_test_start(self, *_) -> None: |
1173 | | - assert model.device.index == 1 |
1174 | | - |
1175 | | - def on_test_batch_start( |
1176 | | - self, |
1177 | | - trainer: Trainer, |
1178 | | - pl_module: LightningModule, |
1179 | | - batch: Any, |
1180 | | - *_, |
1181 | | - ) -> None: |
1182 | | - assert batch.device.index == 1 |
1183 | | - |
1184 | | - model = BoringModel() |
1185 | | - trainer = Trainer( |
1186 | | - default_root_dir=tmpdir, |
1187 | | - fast_dev_run=True, |
1188 | | - accelerator="gpu", |
1189 | | - devices=[1], |
1190 | | - strategy="deepspeed", |
1191 | | - callbacks=TestCallback(), |
1192 | | - enable_progress_bar=False, |
1193 | | - enable_model_summary=False, |
1194 | | - ) |
1195 | | - trainer.fit(model) |
1196 | | - trainer.test(model) |
1197 | | - |
1198 | | - |
1199 | 1158 | @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) |
1200 | 1159 | def test_deepspeed_multi_save_same_filepath(tmpdir): |
1201 | 1160 | """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old |
@@ -1306,3 +1265,19 @@ def transfer_batch_to_device(self, batch, *args, **kwargs): |
1306 | 1265 | batch = trainer.strategy.batch_to_device(batch) |
1307 | 1266 | assert batch.is_cuda |
1308 | 1267 | assert batch.dtype is torch.float16 |
| 1268 | + |
| 1269 | + |
| 1270 | +@RunIf(deepspeed=True) |
| 1271 | +@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]]) |
| 1272 | +def test_validate_parallel_devices_indices(device_indices): |
| 1273 | + """Test that the strategy validates that it doesn't support selecting specific devices by index. |
| 1274 | +
|
| 1275 | + DeepSpeed doesn't support it and needs the index to match to the local rank of the process. |
| 1276 | + """ |
| 1277 | + strategy = DeepSpeedStrategy( |
| 1278 | + accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices] |
| 1279 | + ) |
| 1280 | + with pytest.raises( |
| 1281 | + RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes") |
| 1282 | + ): |
| 1283 | + strategy.setup_environment() |
0 commit comments