@@ -169,5 +169,44 @@ def test_set_timeout(init_process_group_mock):
169
169
global_rank = strategy .cluster_environment .global_rank ()
170
170
world_size = strategy .cluster_environment .world_size ()
171
171
init_process_group_mock .assert_called_with (
172
- process_group_backend , rank = global_rank , world_size = world_size , timeout = test_timedelta
172
+ process_group_backend , rank = global_rank , world_size = world_size , timeout = test_timedelta , device_id = None
173
+ )
174
+
175
+
176
@mock.patch("torch.distributed.init_process_group")
def test_device_id_passed_for_cuda_devices(init_process_group_mock):
    """Verify the ``device_id`` handed to ``init_process_group``.

    A strategy built on a CPU device is expected to pass ``device_id=None``;
    a strategy built on a CUDA device is expected to pass that device.

    """

    def _setup_and_check(device, expected_device_id):
        # Build a single-device strategy with a default cluster environment and
        # a stand-in accelerator, then run the environment setup under test.
        strategy = DDPStrategy(parallel_devices=[device])
        strategy.cluster_environment = LightningEnvironment()
        strategy.accelerator = Mock()
        strategy.setup_environment()

        # Expected arguments are read back from the strategy itself, so the
        # assertion only pins down how they are forwarded.
        backend = strategy._get_process_group_backend()
        rank = strategy.cluster_environment.global_rank()
        size = strategy.cluster_environment.world_size()

        init_process_group_mock.assert_called_with(
            backend,
            rank=rank,
            world_size=size,
            timeout=strategy._timeout,
            device_id=expected_device_id,
        )

    # CPU device: no device_id is expected.
    _setup_and_check(torch.device("cpu"), None)

    # Drop the recorded CPU call so the CUDA assertion starts from a clean mock.
    init_process_group_mock.reset_mock()

    # CUDA device: the device itself is expected as device_id.
    gpu = torch.device("cuda", 0)
    _setup_and_check(gpu, gpu)
0 commit comments