diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py
index 1435e5c2003e1..4a98f12cf6168 100644
--- a/examples/fabric/tensor_parallel/train.py
+++ b/examples/fabric/tensor_parallel/train.py
@@ -1,14 +1,13 @@
 import lightning as L
 import torch
 import torch.nn.functional as F
+from data import RandomTokenDataset
 from lightning.fabric.strategies import ModelParallelStrategy
 from model import ModelArgs, Transformer
 from parallelism import parallelize
 from torch.distributed.tensor.parallel import loss_parallel
 from torch.utils.data import DataLoader
 
-from data import RandomTokenDataset
-
 
 def train():
     strategy = ModelParallelStrategy(
diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py
index 37c620f4582f0..6a91e1242e4af 100644
--- a/examples/pytorch/tensor_parallel/train.py
+++ b/examples/pytorch/tensor_parallel/train.py
@@ -1,14 +1,13 @@
 import lightning as L
 import torch
 import torch.nn.functional as F
+from data import RandomTokenDataset
 from lightning.pytorch.strategies import ModelParallelStrategy
 from model import ModelArgs, Transformer
 from parallelism import parallelize
 from torch.distributed.tensor.parallel import loss_parallel
 from torch.utils.data import DataLoader
 
-from data import RandomTokenDataset
-
 
 class Llama3(L.LightningModule):
     def __init__(self):
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 03d90cd5df057..1e94fa1166f93 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -18,6 +18,7 @@
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,6 +30,7 @@
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
@@ -97,6 +99,7 @@ def __init__(
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 4fa771114768d..e17377d4464b0 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -19,6 +19,7 @@
 from collections import OrderedDict
 from collections.abc import Generator, Mapping
 from contextlib import contextmanager
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union
 
@@ -30,6 +31,7 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
@@ -119,6 +121,7 @@ def __init__(
         load_full_weights: bool = False,
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -264,6 +267,7 @@
             precision_plugin=precision_plugin,
             process_group_backend=process_group_backend,
         )
+        self._timeout: Optional[timedelta] = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -364,7 +368,9 @@ def _init_deepspeed_distributed(self) -> None:
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
 
     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
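
With this change, both DeepSpeed strategies accept a `timeout` argument (defaulting to `default_pg_timeout`, torch.distributed's standard process-group timeout) that is forwarded to `deepspeed.init_distributed`. A minimal usage sketch, assuming the Trainer-facing strategy; the stage, device counts, and the `model` / `train_loader` names are placeholders, not part of this diff:

```python
from datetime import timedelta

import lightning as L
from lightning.pytorch.strategies import DeepSpeedStrategy

# Extend the process-group timeout beyond the default, e.g. when rank 0 spends
# a long time preprocessing data or loading a checkpoint before the first
# collective call. The value is forwarded to deepspeed.init_distributed(...).
strategy = DeepSpeedStrategy(
    stage=3,
    timeout=timedelta(hours=2),
)

trainer = L.Trainer(accelerator="cuda", devices=8, strategy=strategy)
# trainer.fit(model, train_dataloaders=train_loader)  # placeholder objects
```

The Fabric variant (`lightning.fabric.strategies.DeepSpeedStrategy`) takes the same keyword, so the example transfers directly to `L.Fabric(strategy=...)`.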