@@ -169,5 +169,44 @@ def test_set_timeout(init_process_group_mock):
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, device_id=None
+    )
+
+
+@mock.patch("torch.distributed.init_process_group")
+def test_device_id_passed_for_cuda_devices(init_process_group_mock):
+    """Test that device_id is passed to init_process_group for CUDA devices but not for CPU."""
+    # Test with CPU device - device_id should be None
+    cpu_strategy = DDPStrategy(parallel_devices=[torch.device("cpu")])
+    cpu_strategy.cluster_environment = LightningEnvironment()
+    cpu_strategy.accelerator = Mock()
+    cpu_strategy.setup_environment()
+
+    process_group_backend = cpu_strategy._get_process_group_backend()
+    global_rank = cpu_strategy.cluster_environment.global_rank()
+    world_size = cpu_strategy.cluster_environment.world_size()
+
+    init_process_group_mock.assert_called_with(
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=cpu_strategy._timeout, device_id=None
+    )
+
+    init_process_group_mock.reset_mock()
+
+    # Test with CUDA device - device_id should be the device
+    cuda_device = torch.device("cuda", 0)
+    cuda_strategy = DDPStrategy(parallel_devices=[cuda_device])
+    cuda_strategy.cluster_environment = LightningEnvironment()
+    cuda_strategy.accelerator = Mock()
+    cuda_strategy.setup_environment()
+
+    process_group_backend = cuda_strategy._get_process_group_backend()
+    global_rank = cuda_strategy.cluster_environment.global_rank()
+    world_size = cuda_strategy.cluster_environment.world_size()
+
+    init_process_group_mock.assert_called_with(
+        process_group_backend,
+        rank=global_rank,
+        world_size=world_size,
+        timeout=cuda_strategy._timeout,
+        device_id=cuda_device,
     )
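
For context, the new test asserts the pattern of forwarding device_id to torch.distributed.init_process_group only when the strategy's root device is a CUDA device. A minimal sketch of that pattern follows; the setup_process_group helper name is hypothetical and this is not Lightning's actual implementation, it only illustrates the behaviour the assertions above check (assuming a PyTorch version whose init_process_group accepts the device_id keyword):

# Minimal sketch (assumption: not Lightning's implementation) of the behaviour
# the test asserts: pass device_id only for CUDA root devices, None for CPU.
from datetime import timedelta

import torch
import torch.distributed as dist


def setup_process_group(
    root_device: torch.device,
    backend: str,
    rank: int,
    world_size: int,
    timeout: timedelta = timedelta(minutes=30),
) -> None:
    # Only CUDA devices get a device_id; CPU (e.g. gloo) runs pass None.
    # Requires a PyTorch release where init_process_group supports device_id.
    device_id = root_device if root_device.type == "cuda" else None
    dist.init_process_group(backend, rank=rank, world_size=world_size, timeout=timeout, device_id=device_id)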