|
25 | 25 | from lightning.fabric.plugins.environments import LightningEnvironment
|
26 | 26 | from lightning.fabric.strategies import DDPStrategy
|
27 | 27 | from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
|
| 28 | +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 |
28 | 29 | from tests_fabric.helpers.runif import RunIf
|
29 | 30 |
|
30 | 31 |
|
@@ -168,6 +169,52 @@ def test_set_timeout(init_process_group_mock):
|
168 | 169 | process_group_backend = strategy._get_process_group_backend()
|
169 | 170 | global_rank = strategy.cluster_environment.global_rank()
|
170 | 171 | world_size = strategy.cluster_environment.world_size()
|
| 172 | + kwargs = {} |
| 173 | + if _TORCH_GREATER_EQUAL_2_3: |
| 174 | + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None |
171 | 175 | init_process_group_mock.assert_called_with(
|
172 |
| - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta |
| 176 | + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs |
| 177 | + ) |
| 178 | + |
| 179 | + |
@mock.patch("torch.distributed.init_process_group")
def test_device_id_passed_for_cuda_devices(init_process_group_mock):
    """Test that ``device_id`` is passed to ``init_process_group`` for CUDA devices but not for CPU.

    ``torch.distributed.init_process_group`` only accepts ``device_id`` on torch >= 2.3,
    so the expectation is gated on ``_TORCH_GREATER_EQUAL_2_3``.
    """

    def _assert_init_called_with_device_id(device, expected_device_id):
        # Run setup_environment on a fresh strategy and verify the ``device_id``
        # forwarded to the (mocked) ``torch.distributed.init_process_group``.
        # The expected value is passed in explicitly so the assertion pins the
        # intended behavior instead of re-deriving it with the same conditional
        # the implementation uses.
        strategy = DDPStrategy(parallel_devices=[device])
        strategy.cluster_environment = LightningEnvironment()
        strategy.accelerator = Mock()
        strategy.setup_environment()

        expected_kwargs = {}
        if _TORCH_GREATER_EQUAL_2_3:
            expected_kwargs["device_id"] = expected_device_id
        init_process_group_mock.assert_called_with(
            strategy._get_process_group_backend(),
            rank=strategy.cluster_environment.global_rank(),
            world_size=strategy.cluster_environment.world_size(),
            timeout=strategy._timeout,
            **expected_kwargs,
        )
        # Clear call history so the next scenario asserts against its own call.
        init_process_group_mock.reset_mock()

    # CPU device: device_id must be None (a CPU root device cannot be bound).
    _assert_init_called_with_device_id(torch.device("cpu"), None)

    # CUDA device: the root device itself should be forwarded as device_id.
    cuda_device = torch.device("cuda", 0)
    _assert_init_called_with_device_id(cuda_device, cuda_device)
|
0 commit comments