6 files changed: +28 -15 lines changed
Original file line number Diff line number Diff line change
4141 _sync_ddp_if_available ,
4242)
4343from lightning .fabric .utilities .distributed import group as _group
44+ from lightning .fabric .utilities .imports import _TORCH_GREATER_EQUAL_2_3
4445from lightning .fabric .utilities .rank_zero import rank_zero_only
4546
4647_DDP_FORK_ALIASES = (
@@ -212,11 +213,13 @@ def _setup_distributed(self) -> None:
212213 self ._set_world_ranks ()
213214 self ._process_group_backend = self ._get_process_group_backend ()
214215 assert self .cluster_environment is not None
216+ kwargs = {"timeout" : self ._timeout }
217+ if _TORCH_GREATER_EQUAL_2_3 :
218+ kwargs ["device_id" ] = self .root_device if self .root_device .type != "cpu" else None
215219 _init_dist_connection (
216220 self .cluster_environment ,
217221 self ._process_group_backend ,
218- timeout = self ._timeout ,
219- device_id = self .root_device if self .root_device .type != "cpu" else None ,
222+ ** kwargs ,
220223 )
221224
222225 def _get_process_group_backend (self ) -> str :
Original file line number Diff line number Diff line change @@ -663,11 +663,13 @@ def _setup_distributed(self) -> None:
663663 self ._set_world_ranks ()
664664 self ._process_group_backend = self ._get_process_group_backend ()
665665 assert self .cluster_environment is not None
666+ kwargs = {"timeout" : self ._timeout }
667+ if _TORCH_GREATER_EQUAL_2_3 :
668+ kwargs ["device_id" ] = self .root_device if self .root_device .type != "cpu" else None
666669 _init_dist_connection (
667670 self .cluster_environment ,
668671 self ._process_group_backend ,
669- timeout = self ._timeout ,
670- device_id = self .root_device if self .root_device .type != "cpu" else None ,
672+ ** kwargs ,
671673 )
672674
673675 def _get_process_group_backend (self ) -> str :
Original file line number Diff line number Diff line change @@ -302,11 +302,13 @@ def _setup_distributed(self) -> None:
302302 self ._set_world_ranks ()
303303 self ._process_group_backend = self ._get_process_group_backend ()
304304 assert self .cluster_environment is not None
305+ kwargs = {"timeout" : self ._timeout }
306+ if _TORCH_GREATER_EQUAL_2_3 :
307+ kwargs ["device_id" ] = self .root_device if self .root_device .type != "cpu" else None
305308 _init_dist_connection (
306309 self .cluster_environment ,
307310 self ._process_group_backend ,
308- timeout = self ._timeout ,
309- device_id = self .root_device if self .root_device .type != "cpu" else None ,
311+ ** kwargs ,
310312 )
311313
312314 def _get_process_group_backend (self ) -> str :
Original file line number Diff line number Diff line change 3636 _sync_ddp_if_available ,
3737)
3838from lightning .fabric .utilities .distributed import group as _group
39- from lightning .fabric .utilities .imports import _IS_WINDOWS
39+ from lightning .fabric .utilities .imports import _IS_WINDOWS , _TORCH_GREATER_EQUAL_2_3
4040from lightning .fabric .utilities .optimizer import _optimizers_to_device
4141from lightning .fabric .utilities .seed import reset_seed
4242from lightning .fabric .utilities .types import ReduceOp
@@ -200,11 +200,13 @@ def setup_distributed(self) -> None:
200200 self .set_world_ranks ()
201201 self ._process_group_backend = self ._get_process_group_backend ()
202202 assert self .cluster_environment is not None
203+ kwargs = {"timeout" : self ._timeout }
204+ if _TORCH_GREATER_EQUAL_2_3 :
205+ kwargs ["device_id" ] = self .root_device if self .root_device .type != "cpu" else None
203206 _init_dist_connection (
204207 self .cluster_environment ,
205208 self ._process_group_backend ,
206- timeout = self ._timeout ,
207- device_id = self .root_device if self .root_device .type != "cpu" else None ,
209+ ** kwargs ,
208210 )
209211
210212 def _get_process_group_backend (self ) -> str :
Original file line number Diff line number Diff line change 6161 _sync_ddp_if_available ,
6262)
6363from lightning .fabric .utilities .distributed import group as _group
64- from lightning .fabric .utilities .imports import _TORCH_GREATER_EQUAL_2_2
64+ from lightning .fabric .utilities .imports import _TORCH_GREATER_EQUAL_2_2 , _TORCH_GREATER_EQUAL_2_3
6565from lightning .fabric .utilities .init import _has_meta_device_parameters_or_buffers
6666from lightning .fabric .utilities .load import _lazy_load , _materialize_tensors
6767from lightning .fabric .utilities .optimizer import _optimizers_to_device
@@ -260,11 +260,13 @@ def setup_environment(self) -> None:
260260
261261 self ._process_group_backend = self ._get_process_group_backend ()
262262 assert self .cluster_environment is not None
263+ kwargs = {"timeout" : self ._timeout }
264+ if _TORCH_GREATER_EQUAL_2_3 :
265+ kwargs ["device_id" ] = self .root_device if self .root_device .type != "cpu" else None
263266 _init_dist_connection (
264267 self .cluster_environment ,
265268 self ._process_group_backend ,
266- timeout = self ._timeout ,
267- device_id = self .root_device if self .root_device .type != "cpu" else None ,
269+ ** kwargs ,
268270 )
269271
270272 # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
Original file line number Diff line number Diff line change 3939 _sync_ddp_if_available ,
4040)
4141from lightning .fabric .utilities .distributed import group as _group
42- from lightning .fabric .utilities .imports import _TORCH_GREATER_EQUAL_2_4
42+ from lightning .fabric .utilities .imports import _TORCH_GREATER_EQUAL_2_3 , _TORCH_GREATER_EQUAL_2_4
4343from lightning .fabric .utilities .init import _materialize_distributed_module
4444from lightning .fabric .utilities .load import _METADATA_FILENAME
4545from lightning .fabric .utilities .optimizer import _optimizers_to_device
@@ -350,11 +350,13 @@ def _setup_distributed(self) -> None:
350350 self .set_world_ranks ()
351351 self ._process_group_backend = self ._get_process_group_backend ()
352352 assert self .cluster_environment is not None
353+ kwargs = {"timeout" : self ._timeout }
354+ if _TORCH_GREATER_EQUAL_2_3 :
355+ kwargs ["device_id" ] = self .root_device if self .root_device .type != "cpu" else None
353356 _init_dist_connection (
354357 self .cluster_environment ,
355358 self ._process_group_backend ,
356- timeout = self ._timeout ,
357- device_id = self .root_device if self .root_device .type != "cpu" else None ,
359+ ** kwargs ,
358360 )
359361
360362 def _get_process_group_backend (self ) -> str :
You can’t perform that action at this time.
0 commit comments