
Commit 91b28fa

Distributed and testing tweaks (#440)
1 parent 200581d commit 91b28fa

File tree

16 files changed: +481, -467 lines


.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ jobs:
           pip install pybind11
           FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
           MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
-          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
+          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
       - name: Run tests
         run: pytest -v -ra .

.github/workflows/docs.yaml

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ jobs:
           pip install pybind11
           FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
           MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \
-          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]"
+          pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]"
       - name: Build the documentation
         run: mkdocs build

fast_llm/data/data/gpt/data.py

Lines changed: 1 addition & 1 deletion

@@ -9,11 +9,11 @@
 from fast_llm.core.distributed import safe_barrier
 from fast_llm.data.data.abstract import Data
 from fast_llm.data.data.gpt.config import GPTDataConfig
+from fast_llm.data.data_loader import SampledDatasetIterator
 from fast_llm.data.dataset.abstract import SampledDataset
 from fast_llm.data.dataset.config import SamplingParameters
 from fast_llm.data.dataset.gpt.config import GPTSamplingData
 from fast_llm.data.dataset.monitor import DatasetMonitor
-from fast_llm.data.iterator import SampledDatasetIterator
 from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
 from fast_llm.data.sample.language_model import LanguageModelBatch
 from fast_llm.engine.config_utils.run import log_main_rank

fast_llm/engine/distributed/config.py

Lines changed: 73 additions & 67 deletions

@@ -97,6 +97,31 @@ def setup(self, group: "ProcessGroup|None"):
     def check_ranks_in_range(self, start, stop):
         check_ranks_in_range(self.global_ranks, start, stop)

+    @classmethod
+    def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: tuple[int, int]) -> typing.Self:
+        start = global_rank
+        rank = 0
+        world_size = 1
+        for i, (size, stride) in enumerate(sizes_and_strides):
+            if i > 0:
+                Assert.multiple(stride, sizes_and_strides[i - 1][1])
+            rank_ = global_rank // stride % size
+            start -= rank_ * stride
+            rank += world_size * rank_
+            world_size *= size
+        global_ranks = [start]
+        for size, stride in sizes_and_strides:
+            if size == 1:
+                continue
+            if len(global_ranks) == 1:
+                global_ranks = range(start, start + size * stride, stride)
+            elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start:
+                global_ranks = range(start, start + size * stride, global_ranks.step)
+            else:
+                global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks]
+        Assert.eq(len(global_ranks), world_size)
+        return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks)
+

 def check_ranks_in_range(global_ranks, start, stop):
     Assert.geq(min(global_ranks), start)
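Note: the new DistributedDim.from_sizes_and_strides classmethod composes one or more (size, stride) sub-dimensions around a given global rank into a single distributed dimension. A minimal standalone sketch of the same arithmetic, with hypothetical names and without the Assert checks or the range-based fast path:

def group_ranks(global_rank: int, *sizes_and_strides: tuple[int, int]) -> tuple[int, list[int]]:
    # Decompose the global rank along each (size, stride) axis.
    start, rank, world_size = global_rank, 0, 1
    for size, stride in sizes_and_strides:
        coord = global_rank // stride % size  # this rank's coordinate along the axis
        start -= coord * stride               # global rank of the group's first member
        rank += world_size * coord            # flattened rank inside the composed group
        world_size *= size
    # Enumerate the global ranks of every member of the composed group.
    ranks = [start]
    for size, stride in sizes_and_strides:
        ranks = [r0 + r1 for r1 in range(0, size * stride, stride) for r0 in ranks]
    return rank, ranks

# Example: 8 GPUs, tensor=2 (stride 1) and sequence-data=2 (stride 2); global rank 5
# lands in the tensor-and-sequence-data group {4, 5, 6, 7} with local rank 1.
print(group_ranks(5, (2, 1), (2, 2)))  # -> (1, [4, 5, 6, 7])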
@@ -112,6 +137,7 @@ class DistributedDimNames:
     sequence_data = "sequence_data"
     batch_data = "batch_data"
     tensor_and_sequence_data = "tensor_and_sequence_data"
+    model_and_sequence_data = "model_and_sequence_data"
     tensor_and_data = "tensor_and_data"


@@ -300,88 +326,68 @@ def _validate(self) -> None:
         else:
             self.distributed_dims = {}

-            data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1)
+            tensor_stride = 1
+            sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1)
+            batch_data_stride = sequence_data_stride * self.sequence_data_parallel
             pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel)

-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.world,
-                    size=self.world_size,
-                    rank=self.rank,
-                    global_ranks=range(self.world_size),
-                )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.world,
+                (self.world_size, 1),
+            )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.data,
+                (self.sequence_data_parallel, sequence_data_stride),
+                (self.batch_data_parallel, batch_data_stride),
+            )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride)
             )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.data,
-                    size=self.data_parallel,
-                    rank=self.data_rank,
-                    global_ranks=self._get_global_ranks(self.data_parallel, data_stride),
-                )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride)
             )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.pipeline,
-                    size=self.pipeline_parallel,
-                    rank=self.pipeline_rank,
-                    global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride),
-                )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.sequence_data,
+                (self.sequence_data_parallel, sequence_data_stride),
             )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.tensor,
-                    size=self.tensor_parallel,
-                    rank=self.tensor_rank,
-                    global_ranks=self._get_global_ranks(self.tensor_parallel, 1),
-                )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride)
             )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.sequence_data,
-                    size=self.sequence_data_parallel,
-                    rank=self.sequence_data_rank,
-                    global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride),
-                )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.tensor_and_sequence_data,
+                (self.tensor_parallel, tensor_stride),
+                (self.sequence_data_parallel, sequence_data_stride),
             )
-            self._add_distributed_dim(
-                DistributedDim(
-                    name=DistributedDimNames.batch_data,
-                    size=self.batch_data_parallel,
-                    rank=self.batch_data_rank,
-                    global_ranks=self._get_global_ranks(
-                        self.batch_data_parallel, data_stride * self.sequence_data_parallel
-                    ),
-                )
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.tensor_and_data,
+                (self.tensor_parallel, tensor_stride),
+                (self.sequence_data_parallel, sequence_data_stride),
+                (self.batch_data_parallel, batch_data_stride),
             )
-            # Global ranks wrong with pipeline first, so we hide the dims as a safety check.
-            if not self.pipeline_first:
-                self._add_distributed_dim(
-                    DistributedDim(
-                        name=DistributedDimNames.tensor_and_sequence_data,
-                        size=self.sequence_data_parallel * self.tensor_parallel,
-                        rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
-                        global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1),
-                    )
-                )
-                self._add_distributed_dim(
-                    DistributedDim(
-                        name=DistributedDimNames.tensor_and_data,
-                        size=self.data_parallel * self.tensor_parallel,
-                        rank=self.tensor_rank + self.data_rank * self.tensor_parallel,
-                        global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1),
-                    )
-                )

-        super()._validate()
+            self._add_distributed_dim_from_sizes_and_strides(
+                DistributedDimNames.model_and_sequence_data,
+                (self.tensor_parallel, tensor_stride),
+                (
+                    (self.pipeline_parallel, pipeline_stride)
+                    if self.pipeline_first
+                    else (self.sequence_data_parallel, sequence_data_stride)
+                ),
+                (
+                    (self.sequence_data_parallel, sequence_data_stride)
+                    if self.pipeline_first
+                    else (self.pipeline_parallel, pipeline_stride)
+                ),
+            )

+        super()._validate()
         if self.reference_config is not None:
             self.compare(self.reference_config, ValueError)
         Assert.in_range(self.rank, 0, self.world_size)
         Assert.in_range(self.local_rank, 0, self.local_world_size)

-    def _get_global_ranks(self, size: int, stride: int) -> range:
-        start = self.rank // (size * stride) * size * stride + self.rank % stride
-        return range(start, start + size * stride, stride)
+    def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None:
+        self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides))

     def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
         Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim)
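Note: a hedged worked example of the new stride layout in _validate, using hypothetical parallelism sizes that are not part of the commit (tensor=2, sequence-data=2, batch-data=2, pipeline=2, pipeline_first=False, 16 GPUs):

# Hypothetical 16-GPU layout; mirrors the stride formulas added above.
tensor_parallel, sequence_data_parallel, batch_data_parallel, pipeline_parallel = 2, 2, 2, 2
data_parallel = sequence_data_parallel * batch_data_parallel  # 4

tensor_stride = 1
sequence_data_stride = tensor_parallel * 1                          # pipeline_first is False
batch_data_stride = sequence_data_stride * sequence_data_parallel
pipeline_stride = tensor_parallel * data_parallel

print(tensor_stride, sequence_data_stride, batch_data_stride, pipeline_stride)  # 1 2 4 8
# model_and_sequence_data then composes (tensor, 1), (sequence_data, 2), (pipeline, 8):
# for global rank 0 that is the 8-way group {0, 1, 2, 3, 8, 9, 10, 11}.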

fast_llm/engine/distributed/distributed.py

Lines changed: 9 additions & 9 deletions

@@ -28,6 +28,7 @@ def __init__(
         local_world_size: int | None = None,
         timeout: float = 60,
         use_cpu: bool = False,
+        init_method: str = "env://",
         backend: DistributedBackend = DistributedBackend.nccl,
     ):

@@ -58,7 +59,7 @@ def __init__(
         # TODO: Allow other init methods?
         self.store, _, _ = next(
             torch.distributed.rendezvous(
-                "env://",
+                init_method,
                 self._rank,
                 self._world_size,
                 timeout=datetime.timedelta(seconds=timeout),
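Note: the new init_method argument is forwarded straight to torch.distributed.rendezvous. A minimal standalone sketch of that call, assuming a single process and a hypothetical local TCP address (not part of the commit):

import datetime

import torch.distributed

init_method = "tcp://127.0.0.1:29500"  # hypothetical address; the default stays "env://"
# rendezvous() yields (store, rank, world_size); rank 0 of a world of 1 here.
store, rank, world_size = next(
    torch.distributed.rendezvous(init_method, 0, 1, timeout=datetime.timedelta(seconds=60))
)
store.set("ready", "1")  # the returned object is a regular torch.distributed store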
@@ -180,14 +181,13 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
         self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor])
         self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data])
         self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data])
-        # Global ranks wrong with pipeline first, so we hide the dims as a safety check.
-        if not self._config.pipeline_first:
-            self.tensor_and_sequence_data_group = self.add_group(
-                self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
-            )
-            self.tensor_and_data_group = self.add_group(
-                self._config.distributed_dims[DistributedDimNames.tensor_and_data]
-            )
+        self.tensor_and_sequence_data_group = self.add_group(
+            self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data]
+        )
+        self.tensor_and_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor_and_data])
+        self.model_and_sequence_data_group = self.add_group(
+            self._config.distributed_dims[DistributedDimNames.model_and_sequence_data]
+        )

         self._config.log_first_rank(f"Setting random seeds...")
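Note: with the pipeline_first guard gone, the combined groups are always constructed. A hypothetical use of the new attribute, assuming an already constructed Distributed object named distributed whose model_and_sequence_data group spans more than one rank:

import torch
import torch.distributed

# Reduce a scalar across the tensor x (sequence-data / pipeline) ranks only.
t = torch.ones(1, device="cuda")
torch.distributed.all_reduce(t, group=distributed.model_and_sequence_data_group)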

fast_llm/utils.py

Lines changed: 4 additions & 2 deletions

@@ -167,7 +167,7 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None):
         )

     @staticmethod
-    def all_equal(x, *args):
+    def all_equal(x, *args, msg=None):
         import torch

         # Make it work for lists and numpy arrays.
@@ -181,7 +181,9 @@ def all_equal(x, *args):
                 index = None if x.numel() == 1 else torch.where(neq)  # noqa
                 raise AssertionError(
                     f"Tensors have {index[0].numel()} different entries out of "
-                    f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}"
+                    f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + ""
+                    if msg is None
+                    else f"| {msg}"
                 )

     @staticmethod
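Note: a hypothetical usage sketch of the extended Assert.all_equal signature; the new keyword only affects the error raised on a mismatch:

import torch

from fast_llm.utils import Assert

Assert.all_equal(torch.arange(4), torch.arange(4), msg="checkpoint roundtrip")  # passes silently
# Assert.all_equal(torch.arange(4), torch.zeros(4), msg="checkpoint roundtrip")
# would raise an AssertionError.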

tests/conftest.py

Lines changed: 2 additions & 1 deletion

@@ -32,7 +32,8 @@
 )

 from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled  # isort: skip
-from tests.utils.utils import result_path, format_resource_report, report_subtest  # isort: skip
+from tests.utils.utils import result_path  # isort: skip
+from tests.utils.subtest import format_resource_report, report_subtest, run_parallel_script  # isort: skip

 # Import all dynamic classes.
 import fast_llm.cli  # isort: skip

tests/models/distributed_test_checkpoint.py

Lines changed: 0 additions & 93 deletions
This file was deleted.
