Skip to content

Commit c253280

Browse files
committed
Skip passing `device_id` to distributed process-group initialization for PyTorch versions older than 2.3, where the parameter is not supported
1 parent 242b40e commit c253280

File tree

6 files changed

+28
-15
lines changed

6 files changed

+28
-15
lines changed

src/lightning/fabric/strategies/ddp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_sync_ddp_if_available,
4242
)
4343
from lightning.fabric.utilities.distributed import group as _group
44+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
4445
from lightning.fabric.utilities.rank_zero import rank_zero_only
4546

4647
_DDP_FORK_ALIASES = (
@@ -212,11 +213,13 @@ def _setup_distributed(self) -> None:
212213
self._set_world_ranks()
213214
self._process_group_backend = self._get_process_group_backend()
214215
assert self.cluster_environment is not None
216+
kwargs = {"timeout": self._timeout}
217+
if _TORCH_GREATER_EQUAL_2_3:
218+
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
215219
_init_dist_connection(
216220
self.cluster_environment,
217221
self._process_group_backend,
218-
timeout=self._timeout,
219-
device_id=self.root_device if self.root_device.type != "cpu" else None,
222+
**kwargs,
220223
)
221224

222225
def _get_process_group_backend(self) -> str:

src/lightning/fabric/strategies/fsdp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,11 +663,13 @@ def _setup_distributed(self) -> None:
663663
self._set_world_ranks()
664664
self._process_group_backend = self._get_process_group_backend()
665665
assert self.cluster_environment is not None
666+
kwargs = {"timeout": self._timeout}
667+
if _TORCH_GREATER_EQUAL_2_3:
668+
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
666669
_init_dist_connection(
667670
self.cluster_environment,
668671
self._process_group_backend,
669-
timeout=self._timeout,
670-
device_id=self.root_device if self.root_device.type != "cpu" else None,
672+
**kwargs,
671673
)
672674

673675
def _get_process_group_backend(self) -> str:

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,13 @@ def _setup_distributed(self) -> None:
302302
self._set_world_ranks()
303303
self._process_group_backend = self._get_process_group_backend()
304304
assert self.cluster_environment is not None
305+
kwargs = {"timeout": self._timeout}
306+
if _TORCH_GREATER_EQUAL_2_3:
307+
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
305308
_init_dist_connection(
306309
self.cluster_environment,
307310
self._process_group_backend,
308-
timeout=self._timeout,
309-
device_id=self.root_device if self.root_device.type != "cpu" else None,
311+
**kwargs,
310312
)
311313

312314
def _get_process_group_backend(self) -> str:

src/lightning/pytorch/strategies/ddp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
_sync_ddp_if_available,
3737
)
3838
from lightning.fabric.utilities.distributed import group as _group
39-
from lightning.fabric.utilities.imports import _IS_WINDOWS
39+
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
4040
from lightning.fabric.utilities.optimizer import _optimizers_to_device
4141
from lightning.fabric.utilities.seed import reset_seed
4242
from lightning.fabric.utilities.types import ReduceOp
@@ -200,11 +200,13 @@ def setup_distributed(self) -> None:
200200
self.set_world_ranks()
201201
self._process_group_backend = self._get_process_group_backend()
202202
assert self.cluster_environment is not None
203+
kwargs = {"timeout": self._timeout}
204+
if _TORCH_GREATER_EQUAL_2_3:
205+
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
203206
_init_dist_connection(
204207
self.cluster_environment,
205208
self._process_group_backend,
206-
timeout=self._timeout,
207-
device_id=self.root_device if self.root_device.type != "cpu" else None,
209+
**kwargs,
208210
)
209211

210212
def _get_process_group_backend(self) -> str:

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
_sync_ddp_if_available,
6262
)
6363
from lightning.fabric.utilities.distributed import group as _group
64-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
64+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
6565
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
6666
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
6767
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -260,11 +260,13 @@ def setup_environment(self) -> None:
260260

261261
self._process_group_backend = self._get_process_group_backend()
262262
assert self.cluster_environment is not None
263+
kwargs = {"timeout": self._timeout}
264+
if _TORCH_GREATER_EQUAL_2_3:
265+
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
263266
_init_dist_connection(
264267
self.cluster_environment,
265268
self._process_group_backend,
266-
timeout=self._timeout,
267-
device_id=self.root_device if self.root_device.type != "cpu" else None,
269+
**kwargs,
268270
)
269271

270272
# if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
_sync_ddp_if_available,
4040
)
4141
from lightning.fabric.utilities.distributed import group as _group
42-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
42+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
4343
from lightning.fabric.utilities.init import _materialize_distributed_module
4444
from lightning.fabric.utilities.load import _METADATA_FILENAME
4545
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -350,11 +350,13 @@ def _setup_distributed(self) -> None:
350350
self.set_world_ranks()
351351
self._process_group_backend = self._get_process_group_backend()
352352
assert self.cluster_environment is not None
353+
kwargs = {"timeout": self._timeout}
354+
if _TORCH_GREATER_EQUAL_2_3:
355+
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
353356
_init_dist_connection(
354357
self.cluster_environment,
355358
self._process_group_backend,
356-
timeout=self._timeout,
357-
device_id=self.root_device if self.root_device.type != "cpu" else None,
359+
**kwargs,
358360
)
359361

360362
def _get_process_group_backend(self) -> str:

0 commit comments

Comments (0)