Commit 3ba4ae7

awaelchli authored and lantiga committed
Validate selected device indices in DeepSpeedStrategy (#17952)
(cherry picked from commit 3f4790b)
1 parent 29af389 commit 3ba4ae7
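
Context for the change: DeepSpeed assigns each process the GPU whose index equals its local rank, so selecting devices at arbitrary indices (e.g. `devices=[1]`) is not supported and previously went undetected. This commit makes both strategies fail fast during setup. Below is a minimal sketch of the enforced behavior — a hypothetical script, not part of the diff, assuming DeepSpeed is installed and at least two CUDA GPUs are visible:

# Sketch of the behavior this commit enforces (assumes >= 2 CUDA GPUs and DeepSpeed installed).
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()

# Selecting a GPU by a non-zero-based index is now rejected during setup:
trainer = Trainer(accelerator="cuda", devices=[1], strategy="deepspeed", fast_dev_run=True)
trainer.fit(model)  # RuntimeError: The selected device indices [1] don't match the local rank values of processes. ...

# Supported alternative: restrict visibility via the environment and keep zero-based selection,
# e.g. run `CUDA_VISIBLE_DEVICES=1 python train.py` with:
#     trainer = Trainer(accelerator="cuda", devices=1, strategy="deepspeed", fast_dev_run=True)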

File tree

8 files changed (+61 −48 lines)

src/lightning/fabric/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -7,6 +7,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2023-07-07
 
+### Added
+
+- Added validation against misconfigured device selection when using the DeepSpeed strategy ([#17952](https://github.com/Lightning-AI/lightning/pull/17952))
+
+
 ### Fixed
 
 - Fixed the emission of a false-positive warning when calling a method on the Fabric-wrapped module that accepts no arguments ([#17875](https://github.com/Lightning-AI/lightning/pull/17875))

src/lightning/fabric/connector.py

Lines changed: 0 additions & 1 deletion
@@ -376,7 +376,6 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
-            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
             if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                 isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
             ):

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 13 additions & 0 deletions
@@ -568,6 +568,8 @@ def _setup_distributed(self) -> None:
                 f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
                 " is used."
             )
+        assert self.parallel_devices is not None
+        _validate_device_index_selection(self.parallel_devices)
         reset_seed()
         self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
@@ -802,3 +804,14 @@ def _validate_state_keys(state: Dict[str, Any]) -> None:
             " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
             + ", ".join(colliding_keys)
         )
+
+
+def _validate_device_index_selection(parallel_devices: List[torch.device]) -> None:
+    selected_device_indices = [device.index for device in parallel_devices]
+    expected_device_indices = list(range(len(parallel_devices)))
+    if selected_device_indices != expected_device_indices:
+        raise RuntimeError(
+            f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes."
+            " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable"
+            f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`."
+        )
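
For readers skimming the diff: the rule enforced by `_validate_device_index_selection` is simply that the i-th selected device must have index i, i.e. match the local rank of the i-th process. A standalone sketch of the same check follows (illustrative only; `indices_match_local_ranks` is a hypothetical helper, and the snippet runs without DeepSpeed or GPUs since it only builds device objects):

import torch

def indices_match_local_ranks(parallel_devices):
    # Same rule as `_validate_device_index_selection`: the i-th device must sit at index i.
    selected = [device.index for device in parallel_devices]
    return selected == list(range(len(parallel_devices)))

print(indices_match_local_ranks([torch.device("cuda", 0), torch.device("cuda", 1)]))  # True  (accepted)
print(indices_match_local_ranks([torch.device("cuda", 1), torch.device("cuda", 0)]))  # False (the strategy raises)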

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.4] - 2023-06-22
 
+
+- Added validation against misconfigured device selection when using the DeepSpeed strategy ([#17952](https://github.com/Lightning-AI/lightning/pull/17952))
+
+
 ### Changed
 
 - Changes to the `NeptuneLogger` ([#16761](https://github.com/Lightning-AI/lightning/pull/16761)):

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 3 additions & 1 deletion
@@ -30,7 +30,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
 from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE, _validate_device_index_selection
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
@@ -325,6 +325,8 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
         return config
 
     def setup_distributed(self) -> None:
+        assert self.parallel_devices is not None
+        _validate_device_index_selection(self.parallel_devices)
         reset_seed()
         self.set_world_ranks()
         rank_zero_only.rank = self.global_rank

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 0 additions & 1 deletion
@@ -428,7 +428,6 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
-            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
             if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                 isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
             ):

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 17 additions & 1 deletion
@@ -21,7 +21,7 @@
 import torch
 from torch.optim import Optimizer
 
-from lightning.fabric.accelerators import CPUAccelerator
+from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator
 from lightning.fabric.strategies import DeepSpeedStrategy
 from tests_fabric.helpers.runif import RunIf
 
@@ -341,3 +341,19 @@ def test_errors_grad_clipping():
         ),
     ):
         strategy.clip_gradients_value(Mock(), Mock(), Mock())
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]])
+def test_validate_parallel_devices_indices(device_indices):
+    """Test that the strategy validates that it doesn't support selecting specific devices by index.
+
+    DeepSpeed doesn't support it and needs the index to match to the local rank of the process.
+    """
+    strategy = DeepSpeedStrategy(
+        accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
+    )
+    with pytest.raises(
+        RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
+    ):
+        strategy.setup_environment()

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 19 additions & 44 deletions
@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+from re import escape
 from typing import Any, Dict
 from unittest import mock
 
@@ -26,12 +27,12 @@
 from torchmetrics import Accuracy
 
 from lightning.pytorch import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.accelerators import CUDAAccelerator
 from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin
-from lightning.pytorch.strategies import DeepSpeedStrategy
-from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
 from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -1154,48 +1155,6 @@ def test_deepspeed_gradient_clip_by_value(tmpdir):
     trainer.fit(model)
 
 
-@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_specific_gpu_device_id(tmpdir):
-    class TestCallback(Callback):
-        def on_train_start(self, *_) -> None:
-            assert model.device.index == 1
-
-        def on_train_batch_start(
-            self,
-            trainer: Trainer,
-            pl_module: LightningModule,
-            batch: Any,
-            *_,
-        ) -> None:
-            assert batch.device.index == 1
-
-        def on_test_start(self, *_) -> None:
-            assert model.device.index == 1
-
-        def on_test_batch_start(
-            self,
-            trainer: Trainer,
-            pl_module: LightningModule,
-            batch: Any,
-            *_,
-        ) -> None:
-            assert batch.device.index == 1
-
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        accelerator="gpu",
-        devices=[1],
-        strategy="deepspeed",
-        callbacks=TestCallback(),
-        enable_progress_bar=False,
-        enable_model_summary=False,
-    )
-    trainer.fit(model)
-    trainer.test(model)
-
-
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multi_save_same_filepath(tmpdir):
     """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old
@@ -1306,3 +1265,19 @@ def transfer_batch_to_device(self, batch, *args, **kwargs):
     batch = trainer.strategy.batch_to_device(batch)
     assert batch.is_cuda
     assert batch.dtype is torch.float16
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]])
+def test_validate_parallel_devices_indices(device_indices):
+    """Test that the strategy validates that it doesn't support selecting specific devices by index.
+
+    DeepSpeed doesn't support it and needs the index to match to the local rank of the process.
+    """
+    strategy = DeepSpeedStrategy(
+        accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
+    )
+    with pytest.raises(
+        RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
+    ):
+        strategy.setup_environment()
