6 files changed: +28 -15 lines

File 1 of 6:

```diff
@@ -41,6 +41,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 
 _DDP_FORK_ALIASES = (
@@ -212,11 +213,13 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
+        kwargs = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
-            timeout=self._timeout,
-            device_id=self.root_device if self.root_device.type != "cpu" else None,
+            **kwargs,
         )
 
     def _get_process_group_backend(self) -> str:
```
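All six files apply the same change: instead of always passing `device_id` to `_init_dist_connection` (which forwards its keyword arguments to `torch.distributed.init_process_group`), the strategies now build the kwargs dict conditionally, since `init_process_group` only handles `device_id` reliably on PyTorch 2.3 and newer. Below is a minimal self-contained sketch of the pattern outside Lightning — the helper name `init_process_group_compat` and the version-flag computation are illustrative assumptions, not Lightning's actual internals:

```python
from datetime import timedelta

import torch
import torch.distributed as dist
from packaging.version import Version  # assumes `packaging` is installed

# Illustrative stand-in for Lightning's _TORCH_GREATER_EQUAL_2_3 flag
# (an assumption: Lightning computes its flag elsewhere, but the intent matches).
_TORCH_GREATER_EQUAL_2_3 = Version(torch.__version__.split("+")[0]) >= Version("2.3.0")


def init_process_group_compat(backend: str, root_device: torch.device, timeout: timedelta) -> None:
    """Initialize the default process group, passing ``device_id`` only where supported."""
    kwargs = {"timeout": timeout}
    if _TORCH_GREATER_EQUAL_2_3:
        # Bind the group to an accelerator device when there is one;
        # CPU runs pass None, exactly as in the diffs on this page.
        kwargs["device_id"] = root_device if root_device.type != "cpu" else None
    dist.init_process_group(backend, **kwargs)
```

Splatting a conditionally built dict keeps a single call site per strategy, rather than duplicating the whole `_init_dist_connection(...)` call in an if/else on the torch version.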
File 2 of 6:

```diff
@@ -663,11 +663,13 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
+        kwargs = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
-            timeout=self._timeout,
-            device_id=self.root_device if self.root_device.type != "cpu" else None,
+            **kwargs,
         )
 
     def _get_process_group_backend(self) -> str:
```
File 3 of 6:

```diff
@@ -302,11 +302,13 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
+        kwargs = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
-            timeout=self._timeout,
-            device_id=self.root_device if self.root_device.type != "cpu" else None,
+            **kwargs,
         )
 
     def _get_process_group_backend(self) -> str:
```
File 4 of 6:

```diff
@@ -36,7 +36,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ReduceOp
@@ -200,11 +200,13 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
+        kwargs = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
-            timeout=self._timeout,
-            device_id=self.root_device if self.root_device.type != "cpu" else None,
+            **kwargs,
         )
 
     def _get_process_group_backend(self) -> str:
```
File 5 of 6:

```diff
@@ -61,7 +61,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -260,11 +260,13 @@ def setup_environment(self) -> None:
 
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
+        kwargs = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
-            timeout=self._timeout,
-            device_id=self.root_device if self.root_device.type != "cpu" else None,
+            **kwargs,
         )
 
         # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
```
File 6 of 6:

```diff
@@ -39,7 +39,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.init import _materialize_distributed_module
 from lightning.fabric.utilities.load import _METADATA_FILENAME
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -350,11 +350,13 @@ def _setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
+        kwargs = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
         _init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
-            timeout=self._timeout,
-            device_id=self.root_device if self.root_device.type != "cpu" else None,
+            **kwargs,
        )
 
     def _get_process_group_backend(self) -> str:
```
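For a quick smoke test of the `init_process_group_compat` helper sketched earlier (a hypothetical helper, not part of Lightning), a single-process gloo group can be brought up on CPU, where the guard resolves `device_id` to `None` on PyTorch 2.3+ and omits the keyword entirely on older releases:

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist

# Single-process rendezvous: rank 0 of a world of size 1, gloo backend on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

init_process_group_compat("gloo", torch.device("cpu"), timeout=timedelta(seconds=30))
assert dist.is_initialized()
dist.destroy_process_group()
```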