Commit c3e2ba5

set_device before init_process_group (#19184)
1 parent 41f76cd

10 files changed: +24 -11 lines changed

src/lightning/fabric/accelerators/cuda.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def _is_ampere_or_later(device: Optional[torch.device] = None) -> bool:

 @lru_cache(1)  # show the warning only ever once
 def _check_cuda_matmul_precision(device: torch.device) -> None:
-    if not _is_ampere_or_later(device):
+    if not torch.cuda.is_available() or not _is_ampere_or_later(device):
         return
     # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
     # `set_float32_matmul_precision`
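
Why the extra guard: `_is_ampere_or_later` has to query the device's compute capability, and that query can raise on machines without a CUDA runtime, whereas the matmul-precision check should simply be a no-op there. A minimal sketch of the failure mode the guard avoids (the helper body below is an assumption for illustration, not the exact Lightning source):

    import torch

    def _is_ampere_or_later_sketch(device=None):
        # can raise when no CUDA runtime or device is present
        major, _minor = torch.cuda.get_device_capability(device)
        return major >= 8  # Ampere GPUs report compute capability 8.x

    def check(device):
        # short-circuit first, so CPU-only hosts return instead of raising
        if not torch.cuda.is_available() or not _is_ampere_or_later_sketch(device):
            return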

src/lightning/fabric/strategies/ddp.py

Lines changed: 1 addition & 1 deletion
@@ -116,8 +116,8 @@ def _configure_launcher(self) -> None:

     @override
     def setup_environment(self) -> None:
-        self._setup_distributed()
         super().setup_environment()
+        self._setup_distributed()

     @override
     def setup_module(self, module: Module) -> DistributedDataParallel:
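
This reordering is the point of the commit: the base `setup_environment()` is where the accelerator binds the process to its GPU (ultimately `torch.cuda.set_device`), and that binding should exist before `_setup_distributed()` creates the process group, so NCCL communicators are built against the right device. A standalone sketch of the intended order (launcher-style rendezvous env vars are assumed to be set already; the ranks and backend are illustrative):

    import torch
    import torch.distributed as dist

    def setup(local_rank: int, world_size: int) -> None:
        # 1) bind this process to its device first ...
        torch.cuda.set_device(local_rank)
        # 2) ... then create the process group, so NCCL picks up that device
        dist.init_process_group(backend="nccl", rank=local_rank, world_size=world_size)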

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 5 additions & 1 deletion
@@ -604,12 +604,16 @@ def _initialize_engine(
         return deepspeed_engine, deepspeed_optimizer

     @override
-    def _setup_distributed(self) -> None:
+    def setup_environment(self) -> None:
         if not isinstance(self.accelerator, CUDAAccelerator):
             raise RuntimeError(
                 f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
                 " is used."
             )
+        super().setup_environment()
+
+    @override
+    def _setup_distributed(self) -> None:
         assert self.parallel_devices is not None
         _validate_device_index_selection(self.parallel_devices)
         reset_seed()
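
Here the accelerator check is hoisted out of `_setup_distributed()` into a new `setup_environment()` override, so the strategy still fails fast on a non-CUDA accelerator before any device is touched. The resulting call order, paraphrased (the base-class behavior in the comments is inferred from this commit, not quoted source):

    # DeepSpeedStrategy.setup_environment()
    #   1. isinstance check on the accelerator  -> fail fast, nothing configured yet
    #   2. super().setup_environment()          -> accelerator sets the device (set_device),
    #      then the base class calls _setup_distributed()
    #   3. _setup_distributed()                 -> index validation, seeding,
    #                                              process-group init last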

src/lightning/fabric/strategies/fsdp.py

Lines changed: 1 addition & 1 deletion
@@ -251,8 +251,8 @@ def _configure_launcher(self) -> None:

     @override
     def setup_environment(self) -> None:
-        self._setup_distributed()
         super().setup_environment()
+        self._setup_distributed()

     @override
     def setup_module_and_optimizers(

src/lightning/pytorch/strategies/ddp.py

Lines changed: 1 addition & 1 deletion
@@ -150,8 +150,8 @@ def _configure_launcher(self) -> None:

     @override
     def setup_environment(self) -> None:
-        self.setup_distributed()
         super().setup_environment()
+        self.setup_distributed()

     @override
     def setup(self, trainer: "pl.Trainer") -> None:

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 5 additions & 1 deletion
@@ -328,12 +328,16 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
         return config

     @override
-    def setup_distributed(self) -> None:
+    def setup_environment(self) -> None:
         if not isinstance(self.accelerator, CUDAAccelerator):
             raise RuntimeError(
                 f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
                 " is used."
             )
+        super().setup_environment()
+
+    @override
+    def setup_distributed(self) -> None:
         assert self.parallel_devices is not None
         _validate_device_index_selection(self.parallel_devices)
         reset_seed()

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 1 addition & 1 deletion
@@ -248,6 +248,7 @@ def lightning_restore_optimizer(self) -> bool:

     @override
     def setup_environment(self) -> None:
+        super().setup_environment()
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()

@@ -257,7 +258,6 @@ def setup_environment(self) -> None:
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
-        super().setup_environment()

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
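
This strategy performs the distributed bring-up inline rather than in a separate `_setup_distributed()` helper, so the same fix takes a different shape: the `super().setup_environment()` call, which is where the device gets selected, moves from the tail of the method to the head, ahead of `_init_dist_connection()` (the wrapper around `torch.distributed.init_process_group`). The net effect matches the sketch shown after the fabric DDP diff above: set the device first, create the process group second.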

tests/tests_fabric/accelerators/test_cuda.py

Lines changed: 2 additions & 1 deletion
@@ -89,7 +89,8 @@ def test_force_nvml_based_cuda_check():

 @mock.patch("torch.cuda.get_device_capability", return_value=(10, 1))
 @mock.patch("torch.cuda.get_device_name", return_value="Z100")
-def test_tf32_message(_, __, caplog, monkeypatch):
+@mock.patch("torch.cuda.is_available", return_value=True)
+def test_tf32_message(_, __, ___, caplog, monkeypatch):
     # for some reason, caplog doesn't work with our rank_zero_info utilities
     monkeypatch.setattr(lightning.fabric.accelerators.cuda, "rank_zero_info", logging.info)
tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 3 additions & 1 deletion
@@ -400,10 +400,12 @@ def test_validate_parallel_devices_indices(device_indices):
     DeepSpeed doesn't support it and needs the index to match to the local rank of the process.

     """
+    accelerator = Mock(spec=CUDAAccelerator)
     strategy = DeepSpeedStrategy(
-        accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
+        accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
     )
     with pytest.raises(
         RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
     ):
         strategy.setup_environment()
+    accelerator.setup_device.assert_called_once_with(torch.device("cuda", device_indices[0]))
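
The switch from a real `CUDAAccelerator()` to `Mock(spec=CUDAAccelerator)` follows from the reordering: `setup_environment()` now calls `accelerator.setup_device(...)` (and thus `torch.cuda.set_device`) before the device-index validation raises, which would fail on a GPU-less CI machine. A spec'd mock passes the strategy's `isinstance` check while recording the call for the new assertion. A minimal sketch of the pattern:

    from unittest.mock import Mock

    accelerator = Mock(spec=CUDAAccelerator)
    assert isinstance(accelerator, CUDAAccelerator)  # True: spec sets __class__
    accelerator.setup_device("cuda:1")               # recorded, no real CUDA call
    accelerator.setup_device.assert_called_once_with("cuda:1")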

tests/tests_pytorch/strategies/test_deepspeed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from re import escape
1919
from typing import Any, Dict
2020
from unittest import mock
21-
from unittest.mock import ANY
21+
from unittest.mock import ANY, Mock
2222

2323
import pytest
2424
import torch
@@ -1264,13 +1264,15 @@ def test_validate_parallel_devices_indices(device_indices):
12641264
DeepSpeed doesn't support it and needs the index to match to the local rank of the process.
12651265
12661266
"""
1267+
accelerator = Mock(spec=CUDAAccelerator)
12671268
strategy = DeepSpeedStrategy(
1268-
accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
1269+
accelerator=accelerator, parallel_devices=[torch.device("cuda", i) for i in device_indices]
12691270
)
12701271
with pytest.raises(
12711272
RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
12721273
):
12731274
strategy.setup_environment()
1275+
accelerator.setup_device.assert_called_once_with(torch.device("cuda", device_indices[0]))
12741276

12751277

12761278
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
