Skip to content
6 changes: 5 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
load_full_weights: bool = False,
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
+ **kwargs: Any,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
Expand Down Expand Up @@ -239,6 +240,7 @@ def __init__(
cluster_environment=cluster_environment,
precision=precision,
process_group_backend=process_group_backend,
+ **kwargs,
)
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally

Expand Down Expand Up @@ -648,7 +650,9 @@ def _init_deepspeed_distributed(self) -> None:
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
)
self._process_group_backend = self._get_process_group_backend()
- deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+ deepspeed.init_distributed(
+     self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+ )

def _set_node_environment_variables(self) -> None:
assert self.cluster_environment is not None
Expand Down
6 changes: 5 additions & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
load_full_weights: bool = False,
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
+ **kwargs: Any,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
Expand Down Expand Up @@ -263,6 +264,7 @@ def __init__(
cluster_environment=cluster_environment,
precision_plugin=precision_plugin,
process_group_backend=process_group_backend,
+ **kwargs,
)

self.config = self._load_config(config)
Expand Down Expand Up @@ -364,7 +366,9 @@ def _init_deepspeed_distributed(self) -> None:
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
)
self._process_group_backend = self._get_process_group_backend()
- deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+ deepspeed.init_distributed(
+     self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+ )

def _set_node_environment_variables(self) -> None:
assert self.cluster_environment is not None
Expand Down
Loading