Skip to content

Commit 09ec21c

Browse files
committed
make timeout explicit in DeepSpeedStrategy
1 parent ad74bb3 commit 09ec21c

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import platform
1919
from collections.abc import Mapping
2020
from contextlib import AbstractContextManager, ExitStack
21+
from datetime import timedelta
2122
from itertools import chain
2223
from pathlib import Path
2324
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -31,6 +32,7 @@
3132
from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
3233
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
3334
from lightning.fabric.plugins.precision import Precision
35+
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
3436
from lightning.fabric.strategies.ddp import DDPStrategy
3537
from lightning.fabric.strategies.registry import _StrategyRegistry
3638
from lightning.fabric.strategies.strategy import _Sharded
@@ -97,7 +99,7 @@ def __init__(
9799
load_full_weights: bool = False,
98100
precision: Optional[Precision] = None,
99101
process_group_backend: Optional[str] = None,
100-
**kwargs: Any,
102+
timeout: Optional[timedelta] = default_pg_timeout,
101103
) -> None:
102104
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
103105
billion parameter models. `For more information: https://pytorch-
@@ -240,9 +242,9 @@ def __init__(
240242
cluster_environment=cluster_environment,
241243
precision=precision,
242244
process_group_backend=process_group_backend,
243-
**kwargs,
244245
)
245246
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
247+
self._timeout: Optional[timedelta] = timeout
246248

247249
self.config = self._load_config(config)
248250
if self.config is None:
@@ -650,7 +652,7 @@ def _init_deepspeed_distributed(self) -> None:
650652
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
651653
)
652654
self._process_group_backend = self._get_process_group_backend()
653-
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
655+
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
654656

655657
def _set_node_environment_variables(self) -> None:
656658
assert self.cluster_environment is not None

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections import OrderedDict
2020
from collections.abc import Generator, Mapping
2121
from contextlib import contextmanager
22+
from datetime import timedelta
2223
from pathlib import Path
2324
from typing import TYPE_CHECKING, Any, Optional, Union
2425

@@ -30,6 +31,9 @@
3031

3132
import lightning.pytorch as pl
3233
from lightning.fabric.plugins import ClusterEnvironment
34+
from lightning.fabric.plugins.collectives.torch_collective import (
35+
default_pg_timeout
36+
)
3337
from lightning.fabric.strategies import _StrategyRegistry
3438
from lightning.fabric.strategies.deepspeed import (
3539
_DEEPSPEED_AVAILABLE,
@@ -119,7 +123,7 @@ def __init__(
119123
load_full_weights: bool = False,
120124
precision_plugin: Optional[Precision] = None,
121125
process_group_backend: Optional[str] = None,
122-
**kwargs: Any,
126+
timeout: Optional[timedelta] = default_pg_timeout,
123127
) -> None:
124128
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
125129
billion parameter models. `For more information: https://pytorch-
@@ -264,8 +268,8 @@ def __init__(
264268
cluster_environment=cluster_environment,
265269
precision_plugin=precision_plugin,
266270
process_group_backend=process_group_backend,
267-
**kwargs,
268271
)
272+
self._timeout: Optional[timedelta] = timeout
269273

270274
self.config = self._load_config(config)
271275
if self.config is None:
@@ -366,7 +370,7 @@ def _init_deepspeed_distributed(self) -> None:
366370
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
367371
)
368372
self._process_group_backend = self._get_process_group_backend()
369-
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
373+
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout)
370374

371375
def _set_node_environment_variables(self) -> None:
372376
assert self.cluster_environment is not None

0 commit comments

Comments (0)