
Commit 6fc44c9

SkafteNicki, Borda, deependujha, and pre-commit-ci[bot] authored
Add missing device id for pytorch 2.8 (#21105)
* add missing device id for pytorch 2.8
* skip device id for older pytorch versions
* add testing
* fix mypy without touching submodule
* fix failing tests
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Deependu <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3d56296 commit 6fc44c9
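Each of the six touched strategy files applies the same change: build the process-group init kwargs once and attach a device_id only on new enough PyTorch (the diff gates it behind _TORCH_GREATER_EQUAL_2_3), skipping it for CPU. As a quick orientation before the per-file diffs, here is a minimal standalone sketch of that pattern, not the Lightning code itself; the helper name init_distributed, the 30-minute timeout, and the packaging-based version check are illustrative assumptions.

from datetime import timedelta

import torch
import torch.distributed as dist
from packaging.version import Version

# Lightning's flag of the same name lives in lightning.fabric.utilities.imports;
# for this sketch we derive it directly from the installed torch version.
_TORCH_GREATER_EQUAL_2_3 = Version(torch.__version__.split("+")[0]) >= Version("2.3.0")


def init_distributed(root_device: torch.device, backend: str = "nccl") -> None:
    # Rank and world size are assumed to come from the environment (e.g. torchrun).
    kwargs: dict = {"timeout": timedelta(minutes=30)}
    if _TORCH_GREATER_EQUAL_2_3:
        # Older PyTorch versions reject the ``device_id`` argument; CPU runs pass None.
        kwargs["device_id"] = root_device if root_device.type != "cpu" else None
    dist.init_process_group(backend, **kwargs)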

15 files changed: 136 additions & 20 deletions


src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

-
+- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))


 ---
@@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

-- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))
+- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#21119](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))

src/lightning/fabric/strategies/ddp.py

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.rank_zero import rank_zero_only

 _DDP_FORK_ALIASES = (
@@ -212,7 +213,10 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/fabric/strategies/fsdp.py

Lines changed: 4 additions & 1 deletion
@@ -663,7 +663,10 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 4 additions & 1 deletion
@@ -302,7 +302,10 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))


+- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
+

 ---

src/lightning/pytorch/strategies/ddp.py

Lines changed: 5 additions & 2 deletions
@@ -36,7 +36,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ReduceOp
@@ -200,7 +200,10 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 5 additions & 2 deletions
@@ -61,7 +61,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -260,7 +260,10 @@ def setup_environment(self) -> None:

         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

         # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
         if isinstance(self.kwargs.get("device_mesh"), tuple):

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 5 additions & 2 deletions
@@ -39,7 +39,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.init import _materialize_distributed_module
 from lightning.fabric.utilities.load import _METADATA_FILENAME
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -350,7 +350,10 @@ def _setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

tests/tests_fabric/strategies/test_ddp.py

Lines changed: 48 additions & 1 deletion
@@ -25,6 +25,7 @@
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy
 from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from tests_fabric.helpers.runif import RunIf


@@ -168,6 +169,52 @@ def test_set_timeout(init_process_group_mock):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
+    )
+
+
+@mock.patch("torch.distributed.init_process_group")
+def test_device_id_passed_for_cuda_devices(init_process_group_mock):
+    """Test that device_id is passed to init_process_group for CUDA devices but not for CPU."""
+    # Test with CPU device - device_id should be None
+    cpu_strategy = DDPStrategy(parallel_devices=[torch.device("cpu")])
+    cpu_strategy.cluster_environment = LightningEnvironment()
+    cpu_strategy.accelerator = Mock()
+    cpu_strategy.setup_environment()
+
+    process_group_backend = cpu_strategy._get_process_group_backend()
+    global_rank = cpu_strategy.cluster_environment.global_rank()
+    world_size = cpu_strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = cpu_strategy.root_device if cpu_strategy.root_device.type != "cpu" else None
+    init_process_group_mock.assert_called_with(
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=cpu_strategy._timeout, **kwargs
+    )
+
+    init_process_group_mock.reset_mock()
+
+    # Test with CUDA device - device_id should be the device
+    cuda_device = torch.device("cuda", 0)
+    cuda_strategy = DDPStrategy(parallel_devices=[cuda_device])
+    cuda_strategy.cluster_environment = LightningEnvironment()
+    cuda_strategy.accelerator = Mock()
+    cuda_strategy.setup_environment()
+
+    process_group_backend = cuda_strategy._get_process_group_backend()
+    global_rank = cuda_strategy.cluster_environment.global_rank()
+    world_size = cuda_strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = cuda_strategy.root_device if cuda_strategy.root_device.type != "cpu" else None
+    init_process_group_mock.assert_called_with(
+        process_group_backend,
+        rank=global_rank,
+        world_size=world_size,
+        timeout=cuda_strategy._timeout,
+        **kwargs,
     )
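A note on the new test above: it covers both the CPU case (where device_id resolves to None) and the CUDA case entirely through mocks, since the accelerator and torch.distributed.init_process_group are both patched, so it should run on machines without GPUs. A hypothetical local invocation would be pytest tests/tests_fabric/strategies/test_ddp.py -k device_id, where the -k filter matches the new test name.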

tests/tests_fabric/strategies/test_fsdp.py

Lines changed: 5 additions & 2 deletions
@@ -31,7 +31,7 @@
     _get_full_state_dict_context,
     _is_sharded_checkpoint,
 )
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3


 def test_custom_mixed_precision():
@@ -381,8 +381,11 @@ def test_set_timeout(init_process_group_mock):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
     )
