|
@@ -18,6 +18,7 @@
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
@@ -29,6 +30,7 @@
 from typing_extensions import override

 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
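For context, the default for the new parameter comes from Lightning's torch collective helpers. A quick way to inspect it is shown below; the concrete value (commonly 30 minutes, mirroring torch.distributed's process-group default) is an assumption and can vary across PyTorch versions.

from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout

# Typically mirrors torch.distributed's process-group default (often 30 minutes);
# the exact value is version-dependent, so treat this as a sanity check.
print(default_pg_timeout)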
@@ -97,6 +99,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@ def __init__(
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout

         self.config = self._load_config(config)
         if self.config is None:
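A minimal usage sketch of the new `timeout` argument, assuming the usual `Fabric` entry point; the `stage`, `accelerator`, and `devices` values below are illustrative assumptions, not taken from this diff.

from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Give collectives up to 90 minutes before timing out during setup and
# long-running synchronization points (values here are illustrative).
strategy = DeepSpeedStrategy(stage=3, timeout=timedelta(minutes=90))
fabric = Fabric(accelerator="cuda", devices=8, strategy=strategy)
fabric.launch()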
@@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )

     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
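The effect of the change above is that the configured timeout is now forwarded into DeepSpeed's distributed initialization instead of falling back to the library default. A rough sketch of what each rank ends up calling; the backend and port are placeholder assumptions, since in the strategy they come from the process-group backend and the cluster environment.

from datetime import timedelta

import deepspeed

# Sketch of the forwarded call; "nccl" and 29500 are placeholders for what the
# strategy resolves from its backend and cluster environment at runtime.
deepspeed.init_distributed(
    "nccl",
    distributed_port=29500,
    timeout=timedelta(minutes=90),
)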
|