|
@@ -18,6 +18,7 @@
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
@@ -31,6 +32,7 @@
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
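
For context, `default_pg_timeout` is the fallback used when the caller does not pass a timeout. A minimal sketch, assuming Lightning's `torch_collective` module simply re-exports the constant from `torch.distributed.constants`:

```python
from datetime import timedelta

# PyTorch's default process-group timeout; the Lightning import above is
# assumed to resolve to this same constant.
from torch.distributed.constants import default_pg_timeout

print(default_pg_timeout)  # typically 0:30:00, i.e. timedelta(minutes=30)
```
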
@@ -97,7 +99,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
-        **kwargs: Any,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
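
With the signature change above, callers get an explicit, typed `timeout` parameter instead of the old `**kwargs` passthrough. A minimal usage sketch (the `stage` and device arguments are illustrative, not part of this diff):

```python
from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Allow slow rendezvous or collectives up to one hour before the process
# group aborts; the value is forwarded to deepspeed.init_distributed().
strategy = DeepSpeedStrategy(stage=3, timeout=timedelta(hours=1))
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
```
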
@@ -240,9 +242,9 @@ def __init__(
             cluster_environment=cluster_environment,
             precision=precision,
             process_group_backend=process_group_backend,
-            **kwargs,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -650,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
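
DeepSpeed passes the timeout through to `torch.distributed.init_process_group`, so the change above is equivalent to calling DeepSpeed directly with an explicit value. A standalone sketch, assuming an NCCL backend and the default rendezvous port:

```python
from datetime import timedelta

import deepspeed

# What the strategy now does internally, spelled out: initialize the
# process group with a custom timeout instead of torch's 30-minute default.
deepspeed.init_distributed(
    dist_backend="nccl",
    distributed_port=29500,
    timeout=timedelta(hours=1),
)
```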
|