diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 5083c5121..ba5be883a 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,13 +1,14 @@ import logging +import typing -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SampledDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig -from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.sample.language_model import LanguageModelSample logger = logging.getLogger(__name__) @@ -22,12 +23,11 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? - datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field( + datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 2a18afd50..de47ef761 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -2,7 +2,6 @@ import pathlib import typing import warnings -from functools import partial import torch import torch.utils.data @@ -14,7 +13,7 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator -from fast_llm.data.sample.gpt import GPTBatch, GPTSample +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -24,32 +23,9 @@ logger = logging.getLogger(__name__) -def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - stacked_spans = None - sequence_lengths = None - stacked_chosen_spans = None - stacked_rejected_spans = None - if sampling_parameters.use_loss_masking_spans: - stacked_spans = [sample.loss_masking_spans for sample in batch] - if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [sample.chosen_span for sample in batch] - stacked_rejected_spans = [sample.rejected_span for sample in batch] - if not sampling_parameters.cross_document_attention: - sequence_lengths = [sample.sequence_lengths for sample in batch] - return GPTBatch( - token_ids=torch.stack([sample.token_ids for sample in batch]), - loss_masking_spans=stacked_spans, - sequence_lengths=sequence_lengths, - chosen_spans=stacked_chosen_spans, - rejected_spans=stacked_rejected_spans, - ) - - class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. - Currently hard-coded to a GPT dataset. - TODO: Separate generic and GPT classes. 
""" _datasets: dict[str, SampledDataset] @@ -124,7 +100,7 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[GPTBatch]: + ) -> typing.Iterator[LanguageModelBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, @@ -149,10 +125,7 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=partial( - gpt_data_collate_fn, - sampling_parameters=sampling_parameters, - ), + collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7a8d3567d..20e40b66e 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -1,4 +1,5 @@ import dataclasses +import enum import functools import itertools import math @@ -15,6 +16,17 @@ from fast_llm.engine.distributed.distributed import Distributed +class ShufflingType(str, enum.Enum): + # Shuffle all epochs together. Not extendable. + full = "full" + # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. + epoch = "epoch" + # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. + skip_first_epoch = "skip_first_epoch" + # Disable shuffling entirely. + disabled = "disabled" + + @config_class() class SamplingConfig(Config): """ @@ -26,6 +38,18 @@ class SamplingConfig(Config): desc="Seed for random sampling.", hint=FieldHint.feature, ) + gpu: bool = Field( + default=True, + desc="Enable fast sampling on GPU." + " Note that random sampling works differently on GPU," + " so the sample won't match the CPU equivalent.", + hint=FieldHint.feature, + ) + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, + desc="Shuffling strategy.", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) @@ -34,7 +58,12 @@ class SamplingParameters: Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ + sequence_length: int num_samples: int + truncate_documents: bool = True + # How many extra tokens to add to the sequence length. + # This is used to provide labels even for the last tokens in the sequence. 
+ extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -118,10 +147,7 @@ class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[Sampl def build(self) -> "ConcatenatedDataset": from fast_llm.data.dataset.indexed import ConcatenatedDataset - return self._build(ConcatenatedDataset) - - def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: - return cls(self.name, [dataset.build() for dataset in self.datasets]) + return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) @config_class(dynamic_type={SampledDatasetConfig: "slice"}) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 36412b6ce..15f54ec80 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,5 +1,4 @@ import dataclasses -import enum import pathlib import time import typing @@ -13,64 +12,27 @@ IndexedDatasetConfig, SamplableDatasetConfig, SampledDatasetConfig, - SamplingConfig, SamplingData, SamplingParameters, ) -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset -class ShufflingType(str, enum.Enum): - # Shuffle all epochs together. Not extendable. - full = "full" - # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. - epoch = "epoch" - # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. - skip_first_epoch = "skip_first_epoch" - # Disable shuffling entirely. - disabled = "disabled" - - -@config_class() -class GPTSamplingConfig(SamplingConfig): - """ - A dataset-dependent configuration for sampling. - """ - - gpu: bool = Field( - default=True, - desc="Enable fast sampling on GPU." - " Note that random sampling works differently on GPU," - " so the sample won't match the CPU equivalent.", - hint=FieldHint.feature, - ) - shuffle: ShufflingType = Field( - default=ShufflingType.epoch, - desc="Shuffling strategy.", - hint=FieldHint.feature, - ) - - @dataclasses.dataclass(kw_only=True) class GPTSamplingParameters(SamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - sequence_length: int vocab_size: int use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False - cross_document_attention: bool = True - truncate_documents: bool = True - # How many extra tokens to add to the sequence length. - # This is used to provide labels even for the last tokens in the sequence. - extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -80,12 +42,11 @@ class GPTSamplingData(SamplingData): usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. 
""" - config: GPTSamplingConfig parameters: GPTSamplingParameters @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): +class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -93,14 +54,14 @@ class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[Sampl hint=FieldHint.core, ) - def build(self) -> "GPTRandomDataset": + def build(self) -> "GPTRandomDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomDataset - return GPTRandomDataset(self.name) + return GPTRandomDataset[SampleType](self.name) @config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleType]): +class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -118,14 +79,16 @@ class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleT hint=FieldHint.optional, ) - def build(self) -> "GPTMemmapDataset": + def build(self) -> "GPTMemmapDataset[SampleType]": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset[SampleType]( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens + ) @config_class(dynamic_type={SampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): +class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -235,14 +198,14 @@ class FimConfig(Config): @config_class(dynamic_type={SampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig): +class GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType], FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -250,15 +213,15 @@ class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[Sam def build_and_sample( self, - sampling: SamplingData, - ) -> SampledDataset: + sampling: GPTSamplingData, + ) -> "GPTFimDataset[SampleType]": from fast_llm.data.dataset.gpt.fim import GPTFimDataset - return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) + return GPTFimDataset[SampleType](self, self.dataset.build_and_sample(sampling), sampling) @config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]): +class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. 
""" diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 175a0e549..1fde74530 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -3,11 +3,12 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset[SampleType: GPTSample](SampledDataset[SampleType]): +class GPTFimDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -43,10 +44,13 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> SampleType: # TODO: Use torch methods to avoid back and forth. - return GPTSample( - torch.from_numpy( - self._fim( - self._dataset[index].token_ids.numpy(), np.random.RandomState(seed=(self._seed + index) % MAX_SEED) + return LanguageModelSample( + TokenSample( + torch.from_numpy( + self._fim( + self._dataset[index].tokens.tokens.numpy(), + np.random.RandomState(seed=(self._seed + index) % MAX_SEED), + ) ) ) ) @@ -79,19 +83,19 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) new_samples.append(permuted) - sample = np.concatenate(new_samples) + fim_sample = np.concatenate(new_samples) else: - sample = self._fim_split_and_permute_sequence(sample, np_rng) + fim_sample = self._fim_split_and_permute_sequence(sample, np_rng) # Truncate or pad sequence to max-length - diff = sample.shape[0] - sample_len + diff = fim_sample.shape[0] - sample_len if diff > 0: # too long - sample = sample[:sample_len] + fim_sample = fim_sample[:sample_len] elif diff < 0: # too short - sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) + fim_sample = np.concatenate([fim_sample, np.full((-1 * diff), self._pad_tok_id)]) # noqa - assert sample.shape[0] == sample_len - return sample + assert fim_sample.shape[0] == sample_len + return fim_sample.astype(sample.dtype) def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: """ @@ -164,9 +168,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype) + middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype) + suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? 
diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c78805380..06d8d7acc 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -8,12 +8,14 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -class GPTMemmapDataset[SampleType: GPTSample](IndexedDataset[SampleType]): +class GPTMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -47,7 +49,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None if self._version >= 3: self._has_preference_spans = struct.unpack(" SampleType: if end is None: end = self.get_document_size(index) - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=end - begin, - offset=self._pointers[index] + begin * np.dtype(self._dtype).itemsize, + sample_size = self._document_sizes[index].item() + assert 0 <= begin <= end <= sample_size, (0, begin, end, sample_size) + token_ids = ( + torch.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=end - begin, + offset=self._pointers[index].item() + begin * self._dtype.itemsize, + ) + if end > begin + else torch.empty(0, dtype=self._dtype) ) - sample_spans = None + if not self._dtype.is_signed: + # Needed because torch doesn't yet support type promotion between signed and unsigned types. TODO: Remove when supported. 
+ token_ids = token_ids.to(torch.int64) if parameters is not None and parameters.use_loss_masking_spans: assert self._spans is not None - sample_spans = self._spans[index] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[(sample_spans[:, 0] < begin + len(token_ids)) & (sample_spans[:, 1] >= begin)] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], begin) - begin # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], begin + len(token_ids) - 1) - begin - sample_spans = torch.from_numpy(sample_spans) - - chosen_span = None - rejected_span = None + # TODO: ====== Store in range format (begin, end) ====== + sample_spans = RangeSample( + [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size + ).crop(begin, end) + else: + sample_spans = None if parameters is not None and parameters.use_preference_loss_spans: if not self._has_preference_spans: @@ -178,34 +182,23 @@ def get_document( raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[index] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < begin + len(token_ids)) & (chosen_span[1] >= begin)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], begin) - begin # offset - chosen_span[1] = np.minimum(chosen_span[1], begin + len(token_ids) - 1) - begin - chosen_span = torch.from_numpy(chosen_span) - - rejected_span = self._rejected_spans[index] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < begin + len(token_ids)) & (rejected_span[1] >= begin) - ][0] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], begin) - begin # offset - rejected_span[1] = np.minimum(rejected_span[1], begin + len(token_ids) - 1) - begin - rejected_span = torch.from_numpy(rejected_span) + # TODO: ====== Store in range format ====== + chosen_spans = RangeSample( + [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], + sample_size, + ).crop(begin, end) + rejected_spans = RangeSample( + [(self._rejected_spans[index][0].item(), self._rejected_spans[index][1].item() + 1)], + sample_size, + ).crop(begin, end) + else: + chosen_spans = rejected_spans = None - return GPTSample( - token_ids=torch.from_numpy(token_ids), + return LanguageModelSample( + tokens=TokenSample(token_ids), loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, ) @property @@ -231,7 +224,11 @@ def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() @classmethod - def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): + def write_dataset( + cls, + prefix: pathlib.Path | str, + documents: typing.Iterable[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]], + ) -> None: # Initialize metadata dtype = None num_documents = 0 @@ -249,29 +246,29 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write the binary data file (.bin) lazily with 
prefix.with_suffix(".bin").open("wb") as bin_stream: - for document in documents: + for token_ids, loss_masking_spans, chosen_span, rejected_span in documents: # Infer dtype from the first document if dtype is None: - dtype = document.token_ids.dtype + dtype = token_ids.dtype assert dtype is not None, "Document dtype could not be inferred from the data." # Ensure all documents have the same dtype - assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." + assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." # Write document to binary file - bin_stream.write(document.token_ids.numpy().tobytes(order="C")) + bin_stream.write(token_ids.numpy().tobytes(order="C")) # Update metadata - doc_length = len(document.token_ids) + doc_length = len(token_ids) lengths.append(doc_length) pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) + if loss_masking_spans is not None: + num_spans.append(len(loss_masking_spans)) + spans.append(loss_masking_spans) + if chosen_span is not None: + chosen_spans.append(chosen_span) + if rejected_span is not None: + rejected_spans.append(rejected_span) offset += doc_length * dtype.itemsize num_documents += 1 diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index c12e4adcc..463c5a7d6 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -3,10 +3,12 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type -class GPTRandomDataset(SamplableDataset): +class GPTRandomDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): """ A dummy dataset that always returns the same random sample, for debugging purposes. """ @@ -22,23 +24,30 @@ def name(self) -> str: return self._name -class GPTRandomSampledDataset[SampleType: GPTSample](SampledDataset[SampleType]): +class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed - self._sequence_length = sampling.parameters.sequence_length - self._vocab_size = sampling.parameters.vocab_size - self._num_samples = sampling.parameters.num_samples + self._parameters = sampling.parameters + # TODO: Support? 
+ assert not self._parameters.use_loss_masking_spans + assert not self._parameters.use_preference_loss_spans + self._dtype = get_unsigned_integer_type(self._parameters.vocab_size).torch def __len__(self) -> int: - return self._num_samples + return self._parameters.num_samples def __getitem__(self, index: int) -> SampleType: - return GPTSample( - torch.from_numpy( - np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 - ) + # TODO: Sample in self._dtype (breaking) + return LanguageModelSample( + TokenSample( + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, + self._parameters.vocab_size, + size=(self._parameters.sequence_length + self._parameters.extra_tokens,), + ) + ).to(self._dtype), ) ) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 238e99bca..46a518cd0 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -9,10 +9,9 @@ import yaml from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType +from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.abstract import Sample -from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -69,16 +68,14 @@ def _lazy_load(self): class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ - A sampled GPT dataset. + A sampled dataset. """ def __init__( self, indexed_dataset: IndexedDataset[SampleType], - # TODO: ====== Remove gpt-specific stuff ====== - sampling: GPTSamplingData, + sampling: SamplingData, ): - assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset self._config = sampling.config self._parameters = sampling.parameters @@ -108,22 +105,15 @@ def __init__( self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") - # keep document sizes and len filtered docs for preference loss masking - if self._parameters.use_preference_loss_spans: - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( - base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") - ) - # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() # No barrier yet to allow running in parallel. - # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. + # There needs to be one before calling `__getitem__`, normally handled through `Data`. def _sample(self) -> None: """ - Create a `GPTSampledDataset` with the requested parameters. + Create a `SampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. 
document_sizes = self._indexed_dataset.get_document_sizes().to(self._device) @@ -152,10 +142,7 @@ def _sample(self) -> None: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. - if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) - elif self._truncate_documents: + if self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -259,24 +246,6 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - # To get a sample on the fly we need to know where it begins, # and this is a non-trivial information because the documents have variable length. # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. @@ -372,42 +341,10 @@ def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). + The returned sample is ready to be concatenated, then fed to a `Model`. 
""" self._lazy_load() - if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get_document( - document_index.item(), - begin=0, - end=self._document_sizes[document_index].item(), - parameters=self._parameters, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = torch.tensor(sequence_lengths) - - return sample - # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -432,8 +369,7 @@ def __getitem__(self, index: int) -> SampleType: token_count = token_start_array[token_start_cumsum_index] - token_ids = [] - loss_masking_spans = [] + documents: list[SampleType] = [] while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -453,8 +389,7 @@ def __getitem__(self, index: int) -> SampleType: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: - # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + documents.append(documents[-1].get_padding(padding_size)) Assert.eq(token_count + padding_size, token_end) break else: @@ -466,45 +401,21 @@ def __getitem__(self, index: int) -> SampleType: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get_document( - document_index, - begin=token_start_index_in_document, - end=token_end_index_in_document, - parameters=self._parameters, + documents.append( + self._indexed_dataset.get_document( + document_index, + begin=token_start_index_in_document, + end=token_end_index_in_document, + parameters=self._parameters, + ) ) - token_ids.append(sample.token_ids) - if self._parameters.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: - span = np.clip( - loss_masking_span + token_count - token_start, - 0, - self._parameters.sequence_length + self._parameters.extra_tokens, - ) - if span[1] >= span[0]: - loss_masking_spans.append(span) # Go to the next document. 
document_sampling_index += 1 token_count += document_size - sequence_lengths = ( - torch.tensor([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - if not self._parameters.cross_document_attention - else None - ) - token_ids = np.concatenate(token_ids, dtype=np.int64) - loss_masking_spans = ( - torch.from_numpy(np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) - if self._parameters.use_loss_masking_spans - else None - ) - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - - return GPTSample( - token_ids=torch.from_numpy(token_ids), - loss_masking_spans=loss_masking_spans, - sequence_lengths=sequence_lengths, - ) + # TODO: ====== Better way to get the class method? ====== + return documents[0].from_documents(documents) @property def name(self) -> str: @@ -517,13 +428,5 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if self._parameters.use_preference_loss_spans: - data["unshuffled_tokens"] = 0 # not used, ignore - elif "unshuffled_tokens" not in data: - # Backward compatibility - # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a8ff187ae..274bbf1b0 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -24,7 +24,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,7 +37,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None - _sample_type: typing.ClassVar[type[GPTSample]] = GPTSample + _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ @@ -142,11 +142,14 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): + # TODO: Yield `LanguageModelSample` if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( + yield ( torch.tensor(item["input_ids"], dtype=self._data_type.torch), torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), + None, + None, ) elif ( "chosen_token_spans" in shard_dataset.column_names @@ -155,14 +158,20 @@ def _document_generator(): and self._config.dataset.rejected_text is not None ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard 
{shard_idx}", unit="docs"): - yield GPTSample( - token_ids=torch.tensor(item["input_ids"], dtype=self._data_type.torch), - chosen_span=torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), - rejected_span=torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), + yield ( + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + None, + torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), + torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(torch.tensor(item["input_ids"], dtype=self._data_type.torch)) + yield ( + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + None, + None, + None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -241,7 +250,7 @@ def run(self) -> None: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) + self._tokenizer = self._config.tokenizer.get_tokenizer() # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 0c640b9b3..031002101 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -1,10 +1,42 @@ import abc +import typing + +if typing.TYPE_CHECKING: + import torch class Sample(abc.ABC): - pass + @classmethod + @abc.abstractmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + pass + + @abc.abstractmethod + def crop(self, begin: int, end: int) -> typing.Self: + pass + + @abc.abstractmethod + def __len__(self) -> int: + pass + + @abc.abstractmethod + def get_padding(self, size: int) -> typing.Self: + pass class Batch(abc.ABC): # TODO: Relate to `BatchConfig`? 
- pass + @classmethod + @abc.abstractmethod + def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: + pass + + @abc.abstractmethod + def to_samples(self) -> list[Sample]: + pass + + def crop(self, begin: int, end: int) -> typing.Self: + return self.from_samples(sample.crop(begin, end) for sample in self.to_samples()) + + def to_device_(self, device: "torch.device | str"): + pass diff --git a/fast_llm/data/sample/gpt.py b/fast_llm/data/sample/gpt.py deleted file mode 100644 index 4bf740462..000000000 --- a/fast_llm/data/sample/gpt.py +++ /dev/null @@ -1,25 +0,0 @@ -import dataclasses -import typing - -from fast_llm.data.sample.abstract import Batch, Sample - -if typing.TYPE_CHECKING: - import torch - - -@dataclasses.dataclass -class GPTSample(Sample): - token_ids: "torch.Tensor" - loss_masking_spans: "torch.Tensor | None" = None - chosen_span: "torch.Tensor | None" = None - rejected_span: "torch.Tensor | None" = None - sequence_lengths: "torch.Tensor | None" = None - - -@dataclasses.dataclass -class GPTBatch(Batch): - token_ids: "torch.Tensor" - loss_masking_spans: "list[torch.Tensor] | None" = None - sequence_lengths: "list[torch.Tensor] | None" = None - chosen_spans: "list[torch.Tensor] | None" = None - rejected_spans: "list[torch.Tensor] | None" = None diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py new file mode 100644 index 000000000..f30188553 --- /dev/null +++ b/fast_llm/data/sample/language_model.py @@ -0,0 +1,107 @@ +import typing + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.data.sample.range import RangeBatch, RangeSample +from fast_llm.data.sample.token import TokenBatch, TokenSample + + +class LanguageModelSample(Sample): + def __init__( + self, + tokens: TokenSample, + loss_masking_spans: RangeSample | None = None, + chosen_spans: RangeSample | None = None, + rejected_spans: RangeSample | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + TokenSample.from_documents([document.tokens for document in documents]), + _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def __len__(self) -> int: + return len(self.tokens) + + def get_padding(self, size: int) -> typing.Self: + return LanguageModelSample( + self.tokens.get_padding(size), + None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), + None if self.chosen_spans is None else self.chosen_spans.get_padding(size), + None if self.rejected_spans is None else self.rejected_spans.get_padding(size), + ) + + +class LanguageModelBatch(Batch): + def __init__( + self, + tokens: TokenBatch, + loss_masking_spans: RangeBatch | None = None, + chosen_spans: RangeBatch | None = None, + rejected_spans: RangeBatch | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = 
loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: + return cls( + TokenBatch.from_samples([sample.tokens for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), + ) + + def to_samples(self) -> list[LanguageModelSample]: + return [ + LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) + for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( + self.tokens.to_samples(), + self.loss_masking_spans.to_samples(), + self.chosen_spans.to_samples(), + self.rejected_spans.to_samples(), + strict=True, + ) + ] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) + + +def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: + return None if any(arg is None for arg in args) else fn(args) + + +def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: + return None if sample_or_batch is None else sample_or_batch.crop(begin, end) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py new file mode 100644 index 000000000..d121a38b6 --- /dev/null +++ b/fast_llm/data/sample/range.py @@ -0,0 +1,49 @@ +import typing + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.utils import get_unique + + +class RangeSample(Sample): + """ + A reusable component holding a set of ranges in a sample. 
+ """ + + def __init__(self, ranges: list[tuple[int, int]], sample_size: int): + self.ranges = ranges + self.sample_size = sample_size + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + document: RangeSample + ranges = [] + sample_size = 0 + for document in documents: + for begin, end in document.ranges: + ranges.extend((begin + sample_size, end + sample_size)) + sample_size += document.sample_size + return cls(ranges, sample_size) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges) + return self.__class__([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) + + def __len__(self) -> int: + return self.sample_size + + def get_padding(self, size: int) -> typing.Self: + return RangeSample([], size) + + +class RangeBatch(Batch): + def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int): + self.sample_size = sample_size + self.ranges = ranges + + @classmethod + def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: + return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples)) + + def to_samples(self) -> list[RangeSample]: + return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges] diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py new file mode 100644 index 000000000..62d1c0e67 --- /dev/null +++ b/fast_llm/data/sample/token.py @@ -0,0 +1,75 @@ +import typing + +import torch + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.utils import Assert + + +class TokenSample(Sample): + def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None): + self.tokens = tokens + # Length of each document in the sample. TODO: Use cumsums instead? + if lengths is None: + lengths = [len(tokens)] + else: + Assert.eq(sum(lengths), len(tokens)) + self.lengths = lengths + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + torch.cat([document.tokens for document in documents]), + sum((document.lengths for document in documents), []), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + if self.lengths == [len(self.tokens)]: + # Shortcut for the frequent case of a single document. 
+ lengths = [sample_size] + else: + begin_ = 0 + lengths = [] + for length in self.lengths: + end_ = begin_ + length + cropped_length = min(end_, end) - max(begin_, begin) + if cropped_length > 0: + lengths.append(cropped_length) + if end_ > end: + break + begin_ = end_ + return self.__class__(self.tokens[begin:end], lengths) + + def __len__(self) -> int: + return len(self.tokens) + + def get_padding(self, size: int) -> typing.Self: + return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + + +class TokenBatch(Batch): + def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None: + self.tokens = tokens + if lengths is None: + lengths = [[tokens.size(1)]] * tokens.size(0) + self.lengths = lengths + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: + return cls( + torch.stack([sample.tokens for sample in samples]), + [sample.lengths for sample in samples], + ) + + def to_samples(self) -> list[TokenSample]: + return [TokenSample(tokens, lengths) for tokens, lengths in zip(self.tokens, self.lengths, strict=True)] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens[:, begin:end], [sample.crop(begin, end).lengths for sample in self.to_samples()] + ) + + def to_device_(self, device: "torch.device | str"): + # Also standardize the dtype while we're here. + self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 96fb53321..270171755 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: cls.base_model_converter_class.export_config(config.base_model), { "model_type": cls.get_huggingface_model_type(), - "architecture": cls.architecture, + "architectures": [cls.architecture], }, ) @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - Assert.eq(config["architecture"], cls.architecture) + Assert.eq(config["architectures"], [cls.architecture]) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) def _create_weight_converters(self) -> list[WeightConverter]: diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index add121c50..1a0fed91b 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -168,6 +168,7 @@ def _set_triton_dtype_map() -> None: def get_unsigned_integer_type(max_size: int) -> DataType: + # TODO: Use uint types (recently added for torch, not enough methods supported yet) if max_size < 2**8: return DataType.uint8 elif max_size < 2**15: diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 3a70f308f..7ab0b9ff6 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,51 +1,25 @@ import torch -def _compute_logprobs_for_preference_spans( - logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor -): - assert torch.all(targets < logits.size(-1)), "Target out of vocab range" +def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, 
dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - # gather log probabilities corresponding to the target tokens - selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - # apply chosen mask - chosen_logp = 0 - for idx, span in enumerate(chosen_spans): - chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - # apply rejected mask - rejected_logp = 0 - for idx, span in enumerate(rejected_spans): - rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - return chosen_logp, rejected_logp, selected_log_probs - - -def _compute_dpo_loss( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta: float, -): - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - diff_logratios = pi_logratios - ref_logratios - - losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) - return losses +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans + ) def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, reference_model_logits: torch.Tensor, - chosen_spans: torch.Tensor, - rejected_spans: torch.Tensor, + chosen_spans: list[list[tuple[int, int]]], + rejected_spans: list[list[tuple[int, int]]], beta: float, grad_output: float | None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -53,21 +27,18 @@ def compute_dpo_loss( logits_ = logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( - logits_, targets, chosen_spans, rejected_spans - ) + policy_log_probabilities = _get_target_log_probabilities(logits_, targets) + policy_log_ratios = _get_target_log_probability_for_spans( + policy_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( - reference_model_logits_, targets, chosen_spans, rejected_spans - ) + reference_log_probabilities = _get_target_log_probabilities(reference_model_logits_, targets) + reference_log_ratios = _get_target_log_probability_for_spans( + reference_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - losses = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - beta=beta, - ) + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? 
+ losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: loss = None diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 167184193..ffbe9955e 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -5,11 +5,12 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -79,7 +80,12 @@ def __init__( peft=peft, return_bias=return_bias, ) - self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + self._implementation = self._config.implementation + if self._implementation == AttentionImplementation.auto: + if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16): + self._implementation = AttentionImplementation.flash + else: + self._implementation = AttentionImplementation.backup self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -209,8 +215,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._distributed.tp_generator): - attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -328,29 +333,10 @@ def _forward( query, key = self._rotary(query, key, kwargs) window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) - - if self._use_flash_attention: - assert _flash_available - with set_generator(self._distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: - out_dims = query.size() - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - input_ = _flash_attn_varlen_func( - query, - key, - value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), - dropout_p=self._config.dropout if self.training else 0.0, - window_size=window_size, - causal=self._config.causal, - softmax_scale=self._softmax_scale, - ).view(*out_dims) - else: + with 
set_generator(self._distributed.tp_generator): + if self._implementation == AttentionImplementation.flash: + assert _flash_available + if self._config.cross_document_attention: input_ = _flash_attn_func( query, key, @@ -359,17 +345,36 @@ def _forward( dropout_p=self._config.dropout if self.training else 0.0, causal=self._config.causal, softmax_scale=self._softmax_scale, + ).flatten(-2) + else: + input_ = ( + _flash_attn_varlen_func( + query.view(-1, query.size(-2), query.size(-1)), + key.view(-1, key.size(-2), key.size(-1)), + value.view(-1, value.size(-2), value.size(-1)), + cu_seqlens_q=kwargs.get(AttentionKwargs.cu_seqlens_q), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + dropout_p=self._config.dropout if self.training else 0.0, + window_size=window_size, + causal=self._config.causal, + softmax_scale=self._softmax_scale, + ) + .view(query.size()) + .flatten(-2) ) - input_ = input_.flatten(-2) - else: - # TODO: Avoid the flattens. - input_ = self._attn_fused( - query.flatten(-2), - key.flatten(-2), - value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], - ) + elif self._implementation == AttentionImplementation.backup: + # TODO: Avoid the flattens. + input_ = self._attn_fused( + query.flatten(-2), + key.flatten(-2), + value.flatten(-2), + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], + ) + else: + raise NotImplementedError(self._implementation) if self._debug.enabled: self._debug(query, "query", self._query_dims, kwargs) @@ -413,8 +418,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._use_flash_attention: + if (not config.hardware) or self._implementation in AttentionImplementation.flash: # Remove non-causal part. (TODO: Support non-causal) + # TODO: Compute is overestimated without cross-document attention. 
attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 if self._config.window_size is not None: @@ -439,10 +445,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(batch, kwargs) - if not self._use_flash_attention: + if self._implementation == AttentionImplementation.backup: self._preprocess_for_backup_attention(batch, kwargs) - elif AttentionKwargs.sequence_lengths in kwargs: - self._preprocess_for_varlen(batch, kwargs) + elif self._implementation == AttentionImplementation.flash: + self._preprocess_for_flash_attention(batch, kwargs) def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: if ( @@ -471,11 +477,11 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + if not self._config.cross_document_attention: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths + for sample_lens in kwargs[AttentionKwargs.sequence_lengths] ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) @@ -485,7 +491,7 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str ) kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_flash_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -495,7 +501,7 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if AttentionKwargs.sequence_lengths not in kwargs: + if self._config.cross_document_attention: return sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 68b6dde91..206fa6e6f 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,10 +1,9 @@ +import enum import logging import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs @@ -32,6 +31,12 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" +class AttentionImplementation(enum.StrEnum): + auto = "auto" + flash = "flash" + backup = "backup" + + @config_class(dynamic_type={MixerConfig: "attention"}) class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. 
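Note: the `auto` value above is resolved once per layer in `Attention.__init__` (see the attention.py hunk earlier in this diff). A minimal restatement of that resolution logic as a standalone helper, for reference only — `AttentionImplementation` and `DataType` are the classes used in this diff, and `flash_available` stands in for the module-level flash-attention import check:

def resolve_attention_implementation(
    implementation: AttentionImplementation, compute_dtype: DataType, flash_available: bool
) -> AttentionImplementation:
    # `auto` prefers flash attention when it is installed and the compute dtype is half precision
    # (float16 or bfloat16); otherwise it falls back to the fused PyTorch ("backup") implementation.
    if implementation != AttentionImplementation.auto:
        return implementation
    if flash_available and compute_dtype in (DataType.float16, DataType.bfloat16):
        return AttentionImplementation.flash
    return AttentionImplementation.backup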
@@ -107,6 +112,17 @@ class AttentionConfig(MixerConfig): " Under muP (if scaling number of heads instead of head_size): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + implementation: AttentionImplementation = Field( + default=AttentionImplementation.auto, + desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.", + hint=FieldHint.feature, + ) + cross_document_attention: bool = Field( + default=True, + desc="Allow for cross-document attention.", + doc="Disable to prevent attention between tokens belonging to different documents.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -121,6 +137,3 @@ def layer_class(self) -> "type[Attention]": from fast_llm.layers.attention.attention import Attention return Attention - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 25fa2d91e..18c64acc4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -53,6 +53,13 @@ class LanguageModelEmbeddingsConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + cross_document_position_embeddings: bool = Field( + default=True, + desc="Allow for cross-document position embeddings.", + doc="Disable to reset position ids at the beginning of each document.", + hint=FieldHint.feature, + ) + dropout: float = Field( default=0.0, desc="Dropout applied to the embedding layer.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 0ad3225c8..61ca1cfc0 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -136,9 +136,12 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: + if not self._config.cross_document_position_embeddings: position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + [ + torch.cat([torch.arange(x) for x in sample_lens]) + for sample_lens in kwargs[LanguageModelKwargs.sequence_lengths] + ] ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a0466..c1ee246f7 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,12 +48,6 @@ class GPTBatchConfig(BatchConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - # TODO: Find a better place for these? - cross_document_attention: bool = Field( - default=True, - desc="Applies attention to tokens from other documents in the packed sequence. 
Set to False for masking attention to other documents.", - hint=FieldHint.feature, - ) use_loss_masking_spans: bool = Field( default=False, desc="Read loss masking spans from the dataset.", diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 4b9849630..e16eac4de 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -226,10 +226,12 @@ def get_converters( class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter + hf_mixer_name: typing.ClassVar[str] = "mixer" class AprielMamba2BlockConverter(MistralBlockConverter): mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter + hf_mixer_name: typing.ClassVar[str] = "mixer" class AprielBlockConverter: diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index bfc7d5569..a9a0909ec 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -2,6 +2,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.models.gpt.conversion.config import MistralCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaAttentionConverter, @@ -10,6 +11,7 @@ LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, + LlamaMLPConverter, ) from fast_llm.utils import safe_merge_dicts @@ -17,14 +19,20 @@ class MistralAttentionConverter(LlamaAttentionConverter): @classmethod def import_config(cls, config: dict) -> dict: - return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]}) + config["attention_bias"] = False + return safe_merge_dicts( + super().import_config(config), + {"window_size": config["sliding_window"]}, + ) @classmethod def export_config(cls, config: AttentionConfig) -> dict: - return safe_merge_dicts( + out = safe_merge_dicts( super().export_config(config), {"sliding_window": config.window_size}, ) + del out["attention_bias"] + return out @classmethod def _check_config(cls, config: AttentionConfig) -> None: @@ -32,8 +40,23 @@ def _check_config(cls, config: AttentionConfig) -> None: assert not config.add_linear_biases +class MistrallMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + assert not config.add_linear_biases + out = super().export_config(config) + del out["mlp_bias"] + return out + + class MistralBlockConverter(LlamaBlockConverter): mixer_converter_class: typing.ClassVar[type[MistralAttentionConverter]] = MistralAttentionConverter + mlp_converter_class: typing.ClassVar[type[MistrallMLPConverter]] = MistrallMLPConverter class MistralDecoderConverter(LlamaDecoderConverter): diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index a76c3712e..34e38469a 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -5,7 +5,8 @@ import torch import transformers.modeling_outputs -from fast_llm.data.sample.gpt import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.token import TokenBatch from 
fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM @@ -80,7 +81,9 @@ def inner_forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) batch = self.fast_llm_base_model.preprocess_batch( - GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration + LanguageModelBatch(TokenBatch(input_ids, lengths=sequence_lenghts)), + phase=PhaseType.inference, + iteration=iteration, ) ((input_, kwargs),) = batch diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bd3c91a38..3295295f6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,7 +3,7 @@ import torch -from fast_llm.data.sample.gpt import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType @@ -40,7 +40,7 @@ def __init__( param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( - self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + self, batch_meta: GPTBatchConfig | LanguageModelBatch, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: # TODO Remove (Move batch splitting elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence @@ -51,7 +51,7 @@ def preprocess_meta( micro_sequence_length = batch_meta.micro_sequence_length truncate_documents = batch_meta.truncate_documents else: - micro_batch_size, sequence_length = batch_meta.shape + micro_batch_size, sequence_length = batch_meta.tokens.tokens.shape if phase != PhaseType.inference: sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length @@ -151,7 +151,7 @@ def preprocess_meta( def preprocess_batch( self, - batch: GPTBatch, + batch: LanguageModelBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, @@ -161,19 +161,10 @@ def preprocess_batch( # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup - if preprocessed_meta is None: - preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) - - _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size - sequence_first = common_kwargs[AttentionKwargs.sequence_first] - max_prediction_distance = self._config.head.max_prediction_distance + batch.to_device_(self._distributed.device) - batch.token_ids = batch.token_ids.to( - device=self._distributed.device, - dtype=torch.int64, - non_blocking=True, - ) + if preprocessed_meta is None: + preprocessed_meta = self.preprocess_meta(batch, phase) reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -191,103 +182,59 @@ def preprocess_batch( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] - token_ids = batch.token_ids - if sequence_first: - # Move the sequence dimension first to make sequence parallel ops more efficient. 
- token_ids = token_ids.transpose(0, 1).contiguous() - preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size - if sequence_first: - tokens = token_ids[sequence_k - sequence_q : sequence_k] - else: - # TODO: Avoid multiple contiguous calls? - tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() - if batch.sequence_lengths is not None: - kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths - if batch.chosen_spans is not None: - kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans - if batch.rejected_spans is not None: - kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans + tokens_end = kwargs_meta[AttentionKwargs.sequence_k_dim].size + tokens_begin = tokens_end - kwargs_meta[AttentionKwargs.sequence_q_dim].size + cropped_tokens = batch.tokens.crop(tokens_begin, tokens_end) # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] - kwargs = { + + kwargs: dict[str, typing.Any] = { **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, + AttentionKwargs.sequence_lengths: batch.tokens.lengths, + **reference_logits[i], } + if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - if sequence_first: - labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] - else: - # TODO: Avoid multiple contiguous calls? - labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() - # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss - # TODO: take ignore_index from config + labels_begin = tokens_begin + 1 + labels_end = tokens_end + self._config.head.max_prediction_distance + + labels = batch.tokens.crop(labels_begin, labels_end).tokens + if batch.loss_masking_spans is not None: - # avoid changing input tokens - labels = labels.clone() - for idx, spans in enumerate(batch.loss_masking_spans): - if not spans.numel(): - continue - valid_spans = spans[ - (spans[:, 0] <= sequence_k + max_prediction_distance - 1) - & (spans[:, 1] >= sequence_offset) - ] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) - valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) - for start, end in valid_spans: - if sequence_first: - loss_mask[start : end + 1, idx] = False - else: - loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) - kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + loss_masking_spans = batch.loss_masking_spans.crop(labels_begin, labels_end) + loss_mask = torch.ones_like(labels, dtype=torch.bool) + for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): + for begin, end in loss_masking_spans: + loss_mask[sample_index, begin:end] = False + if self._config.output_layer.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, 
labels, -100) + + kwargs[LanguageModelKwargs.labels] = ( + labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels + ).contiguous() if batch.chosen_spans is not None: - chosen_valid_spans = [] - for spans in batch.chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_valid_spans = [] - for spans in batch.rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - + kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges + + if batch.rejected_spans is not None: + kwargs[LanguageModelKwargs.rejected_spans] = batch.rejected_spans.crop( + labels_begin, labels_end + ).ranges + + tokens = ( + cropped_tokens.tokens.transpose(0, 1) + if kwargs[AttentionKwargs.sequence_first] + else cropped_tokens.tokens + ).contiguous() self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54ea13dc4..b8fb22ebb 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -27,7 +27,6 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, # OK since DPO is not supported for MTP. 
"use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), - "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } diff --git a/tests/data/common.py b/tests/data/common.py index 3ade0e9bf..e6ab8a265 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,8 +8,14 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig, SamplingParameters -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig, GPTSamplingData, GPTSamplingParameters, ShufflingType +from fast_llm.data.dataset.config import ( + IndexedDatasetConfig, + SampledDatasetConfig, + SamplingConfig, + SamplingParameters, + ShufflingType, +) +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset from fast_llm.data.sample.abstract import Sample @@ -35,7 +41,7 @@ def get_sampling_data( # Config with convenient defaults. distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( - config=GPTSamplingConfig( + config=SamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -88,7 +94,7 @@ def get_test_data_and_compare_samples( expected_samples = {PhaseType.training.value.lower(): expected_samples} assert "sampling" not in config - config["sampling"] = GPTSamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) + config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) data.setup(distributed, sampling_parameters, cache_directory) with NoAutoValidate(): @@ -97,12 +103,15 @@ def get_test_data_and_compare_samples( batch_config.validate() tokens = { phase: torch.stack( - [batch.token_ids[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + [ + batch.tokens.tokens[0] + for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0) + ] ) for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - Assert.all_equal(tokens[phase], expected_samples_) + Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) return data @@ -117,27 +126,30 @@ def compare_indexed_dataset( sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get_document(i).token_ids) for i in range(min(len(dataset), 100))], + [len(dataset.get_document(i).tokens.tokens) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get_document(i).token_ids, np.array(expected_sample, dtype=np.uint16)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): - Assert.all_equal( + print(i) + Assert.eq( dataset.get_document( i, parameters=GPTSamplingParameters( num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True ), - ).loss_masking_spans, - np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2), + 
).loss_masking_spans.ranges, + loss_masking_spans[i], ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal(torch.stack([sampled[i].token_ids for i in range(len(expected_samples))]), expected_samples) + Assert.all_equal( + torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples + ) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -161,7 +173,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s ) seen_tokens = 0 for document_index in document_sampling: - document = sampled._indexed_dataset.get_document(document_index).token_ids + document = sampled._indexed_dataset.get_document(document_index).tokens.tokens all_tokens[seen_tokens : seen_tokens + len(document)] = document[: num_tokens - seen_tokens] seen_tokens += len(document) @@ -172,7 +184,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] for index in range(sampled._parameters.num_samples) ] - token_ids = torch.stack([sampled[i].token_ids for i in range(len(sampled))]) + token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 678bffa21..0099cb50b 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,7 +4,7 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -123,7 +123,7 @@ def test_gpt_blended(): ], "weights": [0.75, 0.25], }, - BlendedDatasetConfig[GPTSample], + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) @@ -162,7 +162,7 @@ def test_gpt_blended_mixed(): ], "weights": [0.6, 0.4], }, - BlendedDatasetConfig[GPTSample], + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index bb4905cb6..5335e01c0 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,5 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -28,7 +28,7 @@ def test_gpt_concatenate(): get_test_dataset() dataset = get_dataset_config( {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, - ConcatenatedDatasetConfig[GPTSample], + ConcatenatedDatasetConfig[LanguageModelSample], ).build() compare_indexed_dataset( dataset, diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 1286bddd7..ca887f3c1 100644 --- a/tests/data/test_memmap.py +++ 
b/tests/data/test_memmap.py @@ -27,8 +27,8 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [[0, 4], [6, 8]], - 13: [[1, 2]], + 10: [(0, 2), (2, 7), (7, 10)], + 13: [(0, 2)], 15: [], } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 388726bfb..601abcf99 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -11,7 +11,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -31,59 +31,44 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): documents = [ - GPTSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) + (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) for _ in range(100) ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): - assert np.array_equal( - dataset.get_document(i).token_ids, document.token_ids, equal_nan=True - ), f"Mismatch for document {i}: {document} != {dataset.get_document(i)}." + for i, (tokens, _, _, _) in enumerate(documents): + Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - def generate_valid_span(max_seq_length): - span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) - return torch.from_numpy(np.sort(span)) +def _generate_valid_span(max_seq_length): + return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() - vocab_size = 1000 - max_seq_length = 8192 - num_samples = 100 +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): documents = [ - GPTSample( - token_ids=torch.from_numpy(np.random.randint(vocab_size, size=max_seq_length).astype(dtype)), - chosen_span=generate_valid_span(max_seq_length=max_seq_length), - rejected_span=generate_valid_span(max_seq_length=max_seq_length), + ( + torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), + None, + _generate_valid_span(100), + _generate_valid_span(100), ) - for _ in range(num_samples) + for _ in range(50) ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): - dataset_item = dataset.get_document( - i, - parameters=GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True - ), - ) - assert np.array_equal( - dataset_item.token_ids, document.token_ids, equal_nan=True - ), f"Token ids mismatch for document {i}: {document} != {dataset.get_document(i)}." 
- - assert np.array_equal( - dataset_item.chosen_span, document.chosen_span, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get_document(i).chosen_span}." - - assert np.array_equal( - dataset_item.rejected_span, document.rejected_span, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get_document(i).rejected_span}." + parameters = GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True + ) + for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): + document = dataset.get_document(i, parameters=parameters) + Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) + Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) + Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) def test_load_metadata_from_hub(): @@ -136,7 +121,7 @@ def test_absent_metadata_local(): def test_split_dataset(): - dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], {"training": 3, "validation": 1}, @@ -164,8 +149,8 @@ def test_split_dataset(): def test_split_datasets_0(): - dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 1, "validation": 1}, @@ -183,8 +168,8 @@ def test_split_datasets_0(): def test_split_datasets_1(): - dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index d7b3021fe..58f4d3dab 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,9 +2,11 @@ import pytest import torch -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters, ShufflingType +from fast_llm.data.dataset.config import ShufflingType +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -61,7 +63,7 @@ def test_gpt_sampled_data(): ) -class SimpleGPTIndexedDataset[SampleType: GPTSample](IndexedDataset[SampleType]): +class SimpleGPTIndexedDataset[SampleType: 
LanguageModelSample](IndexedDataset[SampleType]): # TODO: worth adding to the main codebase? def __init__(self, samples): self._samples = samples @@ -71,7 +73,7 @@ def get_document( ) -> SampleType: if end is None: end = len(self._samples[index]) - return GPTSample(token_ids=torch.tensor(self._samples[index][begin:end], dtype=torch.int64)) + return LanguageModelSample(TokenSample(torch.tensor(self._samples[index][begin:end], dtype=torch.int64))) def __len__(self) -> int: return len(self._samples) @@ -178,4 +180,4 @@ def test_gpt_sample_padding(): else: sampled = dataset.sample(sampling) for idx in range(len(expected_samples)): - Assert.all_equal(sampled[idx].token_ids, np.array(expected_samples[idx])) + Assert.all_equal(sampled[idx].tokens.tokens, np.array(expected_samples[idx])) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index e83387a24..3c6ae10d4 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,5 +1,5 @@ from fast_llm.data.dataset.config import DatasetSliceConfig -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -35,7 +35,7 @@ def test_gpt_slice(): # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, - DatasetSliceConfig[GPTSample], + DatasetSliceConfig[LanguageModelSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3fae970f8..489f5e1c1 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,167 +1,80 @@ -import random - import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert +from tests.utils.dataset import get_random_spans from tests.utils.utils import requires_cuda -def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: - if temperature != 1.0: - logits.div_(temperature) - batch_dim = logits.shape[:-1] - last_dim = logits.shape[-1] - - output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") - log_probs_labels = -output.view(*batch_dim) - - return log_probs_labels - - -def ref_packed_get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - attention_mask, - prompt_id_lens, - packed_seq_lens, -) -> torch.FloatTensor: - labels = labels[:, 1:] - logits = logits[:, :-1, :] - per_token_logps = ref_log_probs_from_logits(logits, labels) - - loss_masks = attention_mask.clone().bool() - - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - loss_masks[0, index : index + prompt_id_lens[i]] = False - index = index + seq_len - - loss_masks = loss_masks[:, 1:] - - logprobs_sums = [] - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - seq = per_token_logps[0, index : index + seq_len - 1] - mask = 
loss_masks[0, index : index + seq_len - 1] - logprobs_sums.append((seq * mask).sum()) - index = index + seq_len - chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] - rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] - - return torch.tensor(chosen_logps), torch.tensor(rejected_logps) - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("batch_size", "seq_length", "vocab_size"), - ( - (2, 32, 50), - (1, 32, 50), - (2, 100, 50), - (2, 32, 200), - ), -) -def test_preference_logps(batch_size, seq_length, vocab_size): - random.seed(0) - torch.manual_seed(0) - - def random_split(seq_length): - min_val = int(seq_length * 0.3) - max_val = int(seq_length * 0.7) - - if max_val < min_val: - max_val = min_val - - a = random.randint(min_val, max_val) - b = seq_length - a - return [a, b] - - logits = torch.randn(batch_size, seq_length, vocab_size) - targets = torch.randint(0, vocab_size, (batch_size, seq_length)) - packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths - prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 75% prompt 25% generation - attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) - - chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due to label shifting - rejected_span = ( - torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 - ) # shift by 1 due to label shifting - - ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( - logits, targets, attention_mask, prompt_id_lens, packed_seq_lens +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans ) - chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( - logits=logits, - targets=targets[:, 1:], - chosen_spans=chosen_span, - rejected_spans=rejected_span, - ) - - ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) - - # check all logps - Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) - # check chosen and rejected summed logps - Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) - Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) - - -def ref_dpo_loss_fcn( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta=1, - label_smoothing=0, +def reference_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, ) -> torch.Tensor: + # TODO: Too similar to the actual implementation. 
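+    # Gathers per-token log-probabilities at the target ids, sums them over the chosen and rejected spans for both the policy and the reference model, and applies -logsigmoid(beta * (policy log-ratio - reference log-ratio)).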
+ policy_log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + ) + policy_chosen_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + policy_rejected_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + reference_log_probs = ( + torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) + .gather(dim=-1, index=targets.unsqueeze(-1)) + .squeeze(-1) + ) + reference_chosen_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + reference_rejected_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps - logits = pi_logratios - ref_logratios - - # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) - losses = ( - -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) - - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing - ) - - loss = losses.mean() - - return loss + return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() def test_dpo_loss(): torch.manual_seed(0) + logits = torch.randn((10, 50, 100), requires_grad=True) + reference_model_logits = torch.randn((10, 50, 100)) + targets = torch.randint(0, 100, (10, 50)) - NUM_SAMPLES = 20 - policy_chosen_logps = torch.rand(NUM_SAMPLES) - policy_rejected_logps = torch.rand(NUM_SAMPLES) - reference_chosen_logps = torch.rand(NUM_SAMPLES) - reference_rejected_logps = torch.rand(NUM_SAMPLES) - betas = torch.rand(NUM_SAMPLES) + spans = get_random_spans(10, 10, 50) - for i in range(NUM_SAMPLES): - fastllm_dpo_loss = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps[i], - policy_rejected_logps=policy_rejected_logps[i], - reference_chosen_logps=reference_chosen_logps[i], - reference_rejected_logps=reference_rejected_logps[i], - beta=betas[i].item(), - ) - ref_dpo_loss = ref_dpo_loss_fcn( - policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), - policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), - reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), - reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), - beta=betas[i].item(), - ) - Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) + fastllm_loss, fast_llm_grad = compute_dpo_loss( + logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 + ) + reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) + reference_loss.backward() + Assert.rms_close(fastllm_loss, reference_loss, 1e-5) + Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) @requires_cuda diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index f057c037f..7447e395a 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,7 +3,6 @@ import numpy as np import pytest -import torch from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract 
import SampledDataset @@ -11,7 +10,7 @@ from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.sampled import logger -from fast_llm.data.sample.gpt import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset @@ -144,18 +143,16 @@ def __getitem__(self, idx: int) -> typing.Any: shuffled_idx = self._shuffle_idx[idx] doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get_document( - self._doc_idx[doc].item(), - begin=(doc == doc_f) * offset_f, - end=offset_l + 1 if doc == doc_l else None, - ) - for doc in range(doc_f, doc_l + 1) - ] - token_ids = torch.cat([sample.token_ids for sample in sample_list]) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample(token_ids=token_ids) + return LanguageModelSample.from_documents( + [ + self._indexed_dataset.get_document( + self._doc_idx[doc].item(), + begin=(doc == doc_f) * offset_f, + end=offset_l + 1 if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] + ) @property def name(self) -> str: diff --git a/tests/test_attention.py b/tests/test_attention.py index a19cba8f0..b86cc95fa 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -3,7 +3,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.attention import Attention -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert @@ -29,7 +29,7 @@ def test_varlen_preprocessing(): micro_sequence_length = 12 sequence_length = 36 attention = Attention( - AttentionConfig(head_size=64), + AttentionConfig(head_size=64, implementation=AttentionImplementation.flash, cross_document_attention=False), DistributedConfig(compute_dtype="bfloat16"), hidden_dim=TensorDim("", 1), lr_scale=None, diff --git a/tests/test_config.py b/tests/test_config.py index 63f2606f1..9a1f542a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ import yaml from fast_llm.config import NoAutoValidate -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -60,7 +60,7 @@ def test_validate_example_config(): GPTTrainerConfig.from_dict(fast_llm_config_dict) -@pytest.mark.parametrize("cls", (GPTSamplingConfig, GPTModelConfig)) +@pytest.mark.parametrize("cls", (SamplingConfig, GPTModelConfig)) def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. 
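Note: throughout this diff, spans are carried as per-sample lists of `(begin, end)` token ranges with an exclusive end (see the `begin:end` slicing in `preprocess_batch` and the `end + 1` conversion in the memmap preference test), replacing the previous inclusive `[start, end]` pairs. A minimal sketch of how such ranges turn into a loss mask, mirroring the logic in `preprocess_batch`; the helper name and shapes here are illustrative only:

import torch

def loss_mask_from_ranges(ranges: list[list[tuple[int, int]]], sequence_length: int) -> torch.Tensor:
    # One row per sample; False marks tokens excluded from the loss.
    mask = torch.ones(len(ranges), sequence_length, dtype=torch.bool)
    for sample_index, sample_ranges in enumerate(ranges):
        for begin, end in sample_ranges:
            mask[sample_index, begin:end] = False  # exclusive end, like a Python slice
    return mask

# Example: loss_mask_from_ranges([[(0, 2), (2, 7)]], 10) masks tokens 0-6 of sample 0 and keeps tokens 7-9.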
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b43923f4d..428dec56b 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -6,7 +6,6 @@ import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.sample.gpt import GPTSample from tests.utils.global_variables import ( DATASET_PREFIX, MODEL_DATASET_PREFIX, @@ -26,6 +25,15 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) +def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int, seed: int = 0): + spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths, [num_samples, max_spans * 2])) + spans = [np.unique(sample_spans).tolist() for sample_spans in spans] + return [ + [(begin, end) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] + for sample_spans in spans + ] + + def get_test_dataset( prefix: pathlib.Path = DATASET_PREFIX, seed: int = 1234, @@ -47,15 +55,27 @@ def get_test_dataset( tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - GPTSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)) + ( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), + None, + None, + None, + ) for document in texts ] if max_spans > 0: - lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) - for sample, span in zip(samples, spans): - span = np.unique(span) - sample.loss_masking_spans = torch.from_numpy(span[: len(span) // 2 * 2].reshape(-1, 2)) + spans = get_random_spans( + len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed + ) + samples = [ + ( + tokens, + torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2), + None, + None, + ) + for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True) + ] GPTMemmapDataset.write_dataset(prefix, samples) yaml.safe_dump(