2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -32,7 +32,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
- name: Run tests
run: pytest -v -ra .

2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -34,7 +34,7 @@ jobs:
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

2 changes: 1 addition & 1 deletion fast_llm/data/data/gpt/data.py
@@ -9,11 +9,11 @@
from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.data_loader import SampledDatasetIterator
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.gpt.config import GPTSamplingData
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import log_main_rank
File renamed without changes.
140 changes: 73 additions & 67 deletions fast_llm/engine/distributed/config.py
@@ -97,6 +97,31 @@ def setup(self, group: "ProcessGroup|None"):
def check_ranks_in_range(self, start, stop):
check_ranks_in_range(self.global_ranks, start, stop)

@classmethod
def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: tuple[int, int]) -> typing.Self:
start = global_rank
rank = 0
world_size = 1
for i, (size, stride) in enumerate(sizes_and_strides):
if i > 0:
Assert.multiple(stride, sizes_and_strides[i - 1][1])
rank_ = global_rank // stride % size
start -= rank_ * stride
rank += world_size * rank_
world_size *= size
global_ranks = [start]
for size, stride in sizes_and_strides:
if size == 1:
continue
if len(global_ranks) == 1:
global_ranks = range(start, start + size * stride, stride)
elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start:
global_ranks = range(start, start + size * stride, global_ranks.step)
else:
global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks]
Assert.eq(len(global_ranks), world_size)
return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks)


def check_ranks_in_range(global_ranks, start, stop):
Assert.geq(min(global_ranks), start)
@@ -112,6 +137,7 @@ class DistributedDimNames:
sequence_data = "sequence_data"
batch_data = "batch_data"
tensor_and_sequence_data = "tensor_and_sequence_data"
model_and_sequence_data = "model_and_sequence_data"
tensor_and_data = "tensor_and_data"


@@ -300,88 +326,68 @@ def _validate(self) -> None:
else:
self.distributed_dims = {}

data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1)
tensor_stride = 1
sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1)
batch_data_stride = sequence_data_stride * self.sequence_data_parallel
pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel)

self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.world,
size=self.world_size,
rank=self.rank,
global_ranks=range(self.world_size),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.world,
(self.world_size, 1),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.data,
(self.sequence_data_parallel, sequence_data_stride),
(self.batch_data_parallel, batch_data_stride),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.data,
size=self.data_parallel,
rank=self.data_rank,
global_ranks=self._get_global_ranks(self.data_parallel, data_stride),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.pipeline,
size=self.pipeline_parallel,
rank=self.pipeline_rank,
global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.sequence_data,
(self.sequence_data_parallel, sequence_data_stride),
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor,
size=self.tensor_parallel,
rank=self.tensor_rank,
global_ranks=self._get_global_ranks(self.tensor_parallel, 1),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.sequence_data,
size=self.sequence_data_parallel,
rank=self.sequence_data_rank,
global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.tensor_and_sequence_data,
(self.tensor_parallel, tensor_stride),
(self.sequence_data_parallel, sequence_data_stride),
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.batch_data,
size=self.batch_data_parallel,
rank=self.batch_data_rank,
global_ranks=self._get_global_ranks(
self.batch_data_parallel, data_stride * self.sequence_data_parallel
),
)
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.tensor_and_data,
(self.tensor_parallel, tensor_stride),
(self.sequence_data_parallel, sequence_data_stride),
(self.batch_data_parallel, batch_data_stride),
)
# Global ranks wrong with pipeline first, so we hide the dims as a safety check.
if not self.pipeline_first:
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_sequence_data,
size=self.sequence_data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1),
)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_data,
size=self.data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.data_rank * self.tensor_parallel,
global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1),
)
)

super()._validate()
self._add_distributed_dim_from_sizes_and_strides(
DistributedDimNames.model_and_sequence_data,
(self.tensor_parallel, tensor_stride),
(
(self.pipeline_parallel, pipeline_stride)
if self.pipeline_first
else (self.sequence_data_parallel, sequence_data_stride)
),
(
(self.sequence_data_parallel, sequence_data_stride)
if self.pipeline_first
else (self.pipeline_parallel, pipeline_stride)
),
)

super()._validate()
if self.reference_config is not None:
self.compare(self.reference_config, ValueError)
Assert.in_range(self.rank, 0, self.world_size)
Assert.in_range(self.local_rank, 0, self.local_world_size)

def _get_global_ranks(self, size: int, stride: int) -> range:
start = self.rank // (size * stride) * size * stride + self.rank % stride
return range(start, start + size * stride, stride)
def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None:
self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides))

def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim)
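Note on the rank bookkeeping above: `DistributedDim.from_sizes_and_strides` folds several parallel dimensions, each given as a `(size, stride)` pair with each stride a multiple of the previous one, into a single flat dimension. The flat rank is accumulated in mixed radix (earlier dimensions are least significant), and the group's global ranks are the Cartesian combinations of the per-dimension offsets around the group's first member. The snippet below is a minimal standalone sketch of that arithmetic for illustration only; the helper name and the example sizes and strides are invented, and nothing is imported from Fast-LLM.

```python
import itertools


def flat_rank_and_global_ranks(global_rank: int, *sizes_and_strides: tuple[int, int]):
    """Illustrative re-derivation of the rank / global-ranks arithmetic (not the library API)."""
    start, rank, world_size = global_rank, 0, 1
    for size, stride in sizes_and_strides:
        dim_rank = global_rank // stride % size  # coordinate of this rank along the dimension
        start -= dim_rank * stride               # global rank of the group's first member
        rank += world_size * dim_rank            # mixed-radix flat rank, earlier dims least significant
        world_size *= size
    # Enumerate the Cartesian product of per-dimension offsets to recover the group's global ranks.
    global_ranks = sorted(
        start + sum(coord * stride for coord, (_, stride) in zip(coords, sizes_and_strides))
        for coords in itertools.product(*(range(size) for size, _ in sizes_and_strides))
    )
    return rank, global_ranks


# Example: tensor parallel of size 2 (stride 1) composed with sequence-data parallel of size 4
# (stride 2), seen from global rank 5: flat rank 5, group ranks [0, 1, 2, 3, 4, 5, 6, 7].
print(flat_rank_and_global_ranks(5, (2, 1), (4, 2)))
```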
18 changes: 9 additions & 9 deletions fast_llm/engine/distributed/distributed.py
@@ -28,6 +28,7 @@ def __init__(
local_world_size: int | None = None,
timeout: float = 60,
use_cpu: bool = False,
init_method: str = "env://",
backend: DistributedBackend = DistributedBackend.nccl,
):

@@ -58,7 +59,7 @@ def __init__(
# TODO: Allow other init methods?
self.store, _, _ = next(
torch.distributed.rendezvous(
"env://",
init_method,
self._rank,
self._world_size,
timeout=datetime.timedelta(seconds=timeout),
@@ -180,14 +181,13 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor])
self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data])
self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data])
# Global ranks wrong with pipeline first, so we hide the dims as a safety check.
if not self._config.pipeline_first:
self.tensor_and_sequence_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
)
self.tensor_and_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.tensor_and_data]
)
self.tensor_and_sequence_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
)
self.tensor_and_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor_and_data])
self.model_and_sequence_data_group = self.add_group(
self._config.distributed_dims[DistributedDimNames.model_and_sequence_data]
)

self._config.log_first_rank(f"Setting random seeds...")

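The new `init_method` argument above is passed straight to `torch.distributed.rendezvous`, so it follows the standard torch.distributed init-method URL convention. A small sketch with the stock PyTorch API (the backend, address, and sizes are placeholders chosen for illustration, not values from this PR):

```python
import torch.distributed as dist

# Standard init-method URL forms accepted by torch.distributed:
#   "env://"                 - read MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment
#   "tcp://127.0.0.1:29500"  - rendezvous over a TCP store at the given address
#   "file:///tmp/rendezvous" - rendezvous over a shared file
dist.init_process_group(
    backend="gloo",                        # CPU-friendly backend for this sketch
    init_method="tcp://127.0.0.1:29500",   # placeholder address and port
    rank=0,
    world_size=1,
)
dist.destroy_process_group()
```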
6 changes: 4 additions & 2 deletions fast_llm/utils.py
@@ -167,7 +167,7 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None):
)

@staticmethod
def all_equal(x, *args):
def all_equal(x, *args, msg=None):
import torch

# Make it work for lists and numpy arrays.
@@ -181,7 +181,9 @@ def all_equal(x, *args):
index = None if x.numel() == 1 else torch.where(neq) # noqa
raise AssertionError(
f"Tensors have {index[0].numel()} different entries out of "
f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}"
f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + ""
+ ("" if msg is None else f" | {msg}")
)

@staticmethod
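For the `Assert.all_equal` change above, a minimal usage sketch of the new `msg` argument (tensor values and message text are invented, and it assumes `Assert` is importable from `fast_llm.utils` as the diff suggests): on mismatch, the supplied context is appended to the element-wise report in the raised `AssertionError`.

```python
import torch

from fast_llm.utils import Assert

# The tensors differ at one position, so this raises an AssertionError whose
# message now ends with the caller-supplied context.
Assert.all_equal(
    torch.tensor([1, 2, 3]),
    torch.tensor([1, 2, 4]),
    msg="logits mismatch after checkpoint reload",
)
```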
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -32,7 +32,8 @@
)

from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip
from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip
from tests.utils.utils import result_path # isort: skip
from tests.utils.subtest import format_resource_report, report_subtest, run_parallel_script # isort: skip

# Import all dynamic classes.
import fast_llm.cli # isort: skip
93 changes: 0 additions & 93 deletions tests/models/distributed_test_checkpoint.py

This file was deleted.
