Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.rank_zero import rank_zero_only

_DDP_FORK_ALIASES = (
Expand Down Expand Up @@ -212,7 +213,14 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(
self.cluster_environment,
self._process_group_backend,
**kwargs,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,14 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(
self.cluster_environment,
self._process_group_backend,
**kwargs,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/fabric/strategies/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,14 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(
self.cluster_environment,
self._process_group_backend,
**kwargs,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
11 changes: 9 additions & 2 deletions src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ReduceOp
Expand Down Expand Up @@ -200,7 +200,14 @@ def setup_distributed(self) -> None:
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(
self.cluster_environment,
self._process_group_backend,
**kwargs,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
11 changes: 9 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
Expand Down Expand Up @@ -260,7 +260,14 @@ def setup_environment(self) -> None:

self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(
self.cluster_environment,
self._process_group_backend,
**kwargs,
)

# if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
Expand Down
11 changes: 9 additions & 2 deletions src/lightning/pytorch/strategies/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME
from lightning.fabric.utilities.optimizer import _optimizers_to_device
Expand Down Expand Up @@ -350,7 +350,14 @@ def _setup_distributed(self) -> None:
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(
self.cluster_environment,
self._process_group_backend,
**kwargs,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
41 changes: 40 additions & 1 deletion tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,44 @@ def test_set_timeout(init_process_group_mock):
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, device_id=None
)


@mock.patch("torch.distributed.init_process_group")
def test_device_id_passed_for_cuda_devices(init_process_group_mock):
    """Test which ``device_id`` the DDP strategy forwards to ``init_process_group`` for CPU and CUDA root devices.

    The strategy adds the ``device_id`` keyword only when ``_TORCH_GREATER_EQUAL_2_3`` is true (``None`` for a CPU
    root device, the root device itself otherwise), so the expected call omits the keyword on older torch versions.

    """
    # Imported locally so the test carries the version flag it depends on.
    from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3

    cuda_device = torch.device("cuda", 0)
    # (root device, ``device_id`` expected on torch >= 2.3)
    cases = [(torch.device("cpu"), None), (cuda_device, cuda_device)]

    for device, expected_device_id in cases:
        # Clear the calls recorded for the previous case.
        init_process_group_mock.reset_mock()

        strategy = DDPStrategy(parallel_devices=[device])
        strategy.cluster_environment = LightningEnvironment()
        strategy.accelerator = Mock()
        strategy.setup_environment()

        expected_kwargs = {
            "rank": strategy.cluster_environment.global_rank(),
            "world_size": strategy.cluster_environment.world_size(),
            "timeout": strategy._timeout,
        }
        if _TORCH_GREATER_EQUAL_2_3:
            expected_kwargs["device_id"] = expected_device_id

        init_process_group_mock.assert_called_with(strategy._get_process_group_backend(), **expected_kwargs)
29 changes: 28 additions & 1 deletion tests/tests_pytorch/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,34 @@ def test_set_timeout(mock_init_process_group):
global_rank = trainer.strategy.cluster_environment.global_rank()
world_size = trainer.strategy.cluster_environment.world_size()
mock_init_process_group.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, device_id=None
)


@mock.patch("torch.distributed.init_process_group")
def test_device_id_passed_for_cuda_devices_pytorch(mock_init_process_group):
    """Test the ``device_id`` forwarded to ``init_process_group`` when the Trainer runs DDP on the CPU accelerator.

    Only the CPU case is exercised here. The strategy adds the ``device_id`` keyword (``None`` for a CPU root
    device) only when ``_TORCH_GREATER_EQUAL_2_3`` is true, so the expected call omits it on older torch versions.

    """
    # Imported locally so the test carries the version flag it depends on.
    from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3

    model = BoringModel()
    trainer = Trainer(
        max_epochs=1,
        accelerator="cpu",
        strategy=DDPStrategy(),
    )
    strategy = trainer.strategy
    strategy.connect(model)
    trainer.lightning_module.trainer = trainer
    strategy.setup_environment()

    expected_kwargs = {
        "rank": strategy.cluster_environment.global_rank(),
        "world_size": strategy.cluster_environment.world_size(),
        "timeout": strategy._timeout,
    }
    if _TORCH_GREATER_EQUAL_2_3:
        expected_kwargs["device_id"] = None

    mock_init_process_group.assert_called_with(strategy._get_process_group_backend(), **expected_kwargs)


Expand Down
Loading