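This merged pull request (a single commit) makes the same change in six strategy files across `lightning.fabric` and `lightning.pytorch`: the one-line `_init_dist_connection(...)` call in each strategy's distributed setup method is expanded to also pass a `device_id`, set to the strategy's `root_device` whenever that device is not a CPU.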
src/lightning/fabric/strategies/ddp.py (6 additions & 1 deletion)

@@ -212,7 +212,12 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
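For context, here is a minimal sketch of the call path these diffs rely on, assuming `_init_dist_connection` forwards extra keyword arguments to `torch.distributed.init_process_group` (which accepts an optional `device_id` in recent PyTorch versions). The helper below is a simplified stand-in for illustration, not Lightning's actual implementation:

import os
from datetime import timedelta
from typing import Any

import torch
import torch.distributed as dist


def init_dist_connection_sketch(backend: str, rank: int, world_size: int, **kwargs: Any) -> None:
    # Simplified stand-in for Lightning's _init_dist_connection: any extra
    # keyword arguments (e.g. timeout, device_id) are passed straight through
    # to torch.distributed.init_process_group.
    if dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend, rank=rank, world_size=world_size, **kwargs)


# What the patched strategies effectively do on a CUDA rank (illustrative values;
# MASTER_ADDR/MASTER_PORT must be set, which the cluster environment normally handles):
# os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
# os.environ.setdefault("MASTER_PORT", "29500")
# init_dist_connection_sketch(
#     "nccl",
#     rank=0,
#     world_size=1,
#     timeout=timedelta(minutes=30),
#     device_id=torch.device("cuda", 0),
# )

Passing `device_id` binds the new process group to a specific accelerator, which lets PyTorch apply backend optimizations such as eager initialization of the NCCL communicator and avoids the device guessing that otherwise happens on the first collective call (e.g. `barrier()`).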
src/lightning/fabric/strategies/fsdp.py (6 additions & 1 deletion)

@@ -663,7 +663,12 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
src/lightning/fabric/strategies/model_parallel.py (6 additions & 1 deletion)

@@ -302,7 +302,12 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
src/lightning/pytorch/strategies/ddp.py (6 additions & 1 deletion)

@@ -200,7 +200,12 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
src/lightning/pytorch/strategies/fsdp.py (6 additions & 1 deletion)

@@ -260,7 +260,12 @@ def setup_environment(self) -> None:

         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

         # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
         if isinstance(self.kwargs.get("device_mesh"), tuple):
src/lightning/pytorch/strategies/model_parallel.py (6 additions & 1 deletion)

@@ -350,7 +350,12 @@ def _setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        _init_dist_connection(
+            self.cluster_environment,
+            self._process_group_backend,
+            timeout=self._timeout,
+            device_id=self.root_device if self.root_device.type != "cpu" else None,
+        )

     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
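Note the guard shared by every hunk: `device_id` falls back to `None` when the root device is a CPU, and since `None` is the default value of `init_process_group`'s `device_id` parameter, CPU-only runs (typically the gloo backend) keep their previous behavior. Illustratively:

import torch

root_device = torch.device("cuda", 0)
device_id = root_device if root_device.type != "cpu" else None  # -> device(type='cuda', index=0)

root_device = torch.device("cpu")
device_id = root_device if root_device.type != "cpu" else None  # -> None, so the kwarg is a no-op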