Commit 9e46f02

Add missing device_id for PyTorch 2.8
1 parent 4ff8ff7 commit 9e46f02

6 files changed: +36 −6 lines changed


src/lightning/fabric/strategies/ddp.py

Lines changed: 6 additions & 1 deletion
@@ -212,7 +212,12 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
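
The same one-line call is expanded in each strategy so that the strategy's root device is forwarded as device_id whenever it is not a CPU device. As a rough sketch of where that argument presumably ends up, assuming _init_dist_connection simply forwards it to torch.distributed.init_process_group (the real helper in lightning.fabric.utilities.distributed takes a cluster environment and may differ in its argument names and defaults), the flow looks roughly like this:

# Hypothetical stand-in for _init_dist_connection, showing only the assumed
# forwarding of device_id; the argument list here is illustrative, not the
# actual Lightning signature.
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed


def init_dist_connection_sketch(
    main_address: str,
    main_port: int,
    global_rank: int,
    world_size: int,
    process_group_backend: str,
    timeout: timedelta = timedelta(minutes=30),
    device_id: Optional[torch.device] = None,
) -> None:
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend=process_group_backend,
            init_method=f"tcp://{main_address}:{main_port}",
            rank=global_rank,
            world_size=world_size,
            timeout=timeout,
            # A non-None device_id binds the default process group to a concrete
            # accelerator device up front instead of letting PyTorch infer it
            # later; passing None keeps the previous behaviour for CPU/gloo runs.
            device_id=device_id,
        )

With that forwarding in place, the device_id=self.root_device if self.root_device.type != "cpu" else None expression in each hunk below gives accelerator devices an explicit binding while CPU runs behave exactly as before.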

src/lightning/fabric/strategies/fsdp.py

Lines changed: 6 additions & 1 deletion
@@ -663,7 +663,12 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/fabric/strategies/model_parallel.py

Lines changed: 6 additions & 1 deletion
@@ -302,7 +302,12 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/pytorch/strategies/ddp.py

Lines changed: 6 additions & 1 deletion
@@ -200,7 +200,12 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 6 additions & 1 deletion
@@ -260,7 +260,12 @@ def setup_environment(self) -> None:

         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

         # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
         if isinstance(self.kwargs.get("device_mesh"), tuple):

src/lightning/pytorch/strategies/model_parallel.py

Lines changed: 6 additions & 1 deletion
@@ -350,7 +350,12 @@ def _setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
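
For completeness, the conditional used in every hunk can be exercised on its own; this small snippet only illustrates the expression and is not part of the commit:

import torch


def device_id_for(root_device: torch.device):
    # Same expression as in the diffs above: accelerator devices yield an
    # explicit device_id, CPU process groups fall back to None (the previous
    # behaviour).
    return root_device if root_device.type != "cpu" else None


print(device_id_for(torch.device("cuda", 0)))  # cuda:0
print(device_id_for(torch.device("cpu")))      # None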
