diff --git a/fast_llm/config.py b/fast_llm/config.py index 9644df9c1..658ad5666 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -492,6 +492,10 @@ def _validate_element(cls, value, type_, name: str): value = cls._validate_dict(value, type_, name) elif origin is type: value = cls._validate_type(value, type_, name) + elif issubclass(origin, Config): + # TODO: Validate arguments for config generics. + cls._validate_element_type(value, type_.__origin__, strict=False) + value.validate(_is_validating=True) else: raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): @@ -806,6 +810,8 @@ def _from_dict_nested(cls, value, type_, strict: bool): value = cls._from_dict_array(value, type_, strict) elif issubclass(origin, dict): value = cls._from_dict_dict(value, type_, strict) + elif issubclass(origin, Config): + value = cls._from_dict_config(value, type_, strict) elif origin is type: pass else: @@ -813,10 +819,15 @@ def _from_dict_nested(cls, value, type_, strict: bool): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type: {type_}.") elif issubclass(type_, Config): - if value is MISSING: - value = {} - if isinstance(value, dict): - value = type_._from_dict(value, strict) + value = cls._from_dict_config(value, type_, strict) + return value + + @classmethod + def _from_dict_config(cls, value, type_, strict: bool): + if value is MISSING: + value = {} + if isinstance(value, dict): + value = type_._from_dict(value, strict) return value @classmethod @@ -938,6 +949,7 @@ def __init_subclass__(cls): We need to postpone validation until the class has been processed by the dataclass wrapper. """ Assert.eq(cls.__name__, cls.__qualname__) + super().__init_subclass__() for base_class in cls.__mro__: if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( @@ -1006,6 +1018,7 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set `config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. 
+ super().__init_subclass__() try: config_class = None for base in types.get_original_bases(cls): diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 4c041945d..633367c80 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,9 +1,13 @@ import enum import pathlib +import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.tokenizer import Tokenizer + class MultiprocessingContext(str, enum.Enum): # Fast but risk of segfaults due to interactions with triton @@ -29,7 +33,7 @@ class TokenizerConfig(Config): hint=FieldHint.deprecated, valid=check_field(Assert.eq, TokenizerFromFile), ) - path: pathlib.Path | None = Field( + path: pathlib.Path = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, @@ -39,3 +43,8 @@ class TokenizerConfig(Config): desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", hint=FieldHint.core, ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.tokenizer import Tokenizer + + return Tokenizer(self) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e24d39985..c67dc0321 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,6 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig @@ -47,5 +48,5 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[Batch]: pass diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index efee46959..5083c5121 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,9 +1,11 @@ import logging from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class -from fast_llm.data.config import MultiprocessingContext, TokenizerConfig +from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig +from fast_llm.data.dataset.config import SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -19,12 +21,8 @@ class GPTDataConfig(DataConfig): _abstract = False - tokenizer: TokenizerConfig = Field( - desc="Configuration for the tokenizer (for FIM).", - hint=FieldHint.feature, - ) # TODO: Review field. Move closer to phase definition in training config? 
- datasets: dict[str, GPTSampledDatasetConfig] = Field( + datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..2a18afd50 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,11 +1,9 @@ -import dataclasses import logging import pathlib import typing import warnings from functools import partial -import numpy as np import torch import torch.utils.data @@ -14,43 +12,32 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.sample.gpt import GPTBatch, GPTSample from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTBatch: - token_ids: torch.Tensor - loss_masking_spans: list[torch.Tensor] | None = None - sequence_lengths: list[torch.Tensor] | None = None - chosen_spans: list[torch.Tensor] | None = None - rejected_spans: list[torch.Tensor] | None = None - - def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None stacked_chosen_spans = None stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: - stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + stacked_spans = [sample.loss_masking_spans for sample in batch] if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] - stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] + stacked_chosen_spans = [sample.chosen_span for sample in batch] + stacked_rejected_spans = [sample.rejected_span for sample in batch] if not sampling_parameters.cross_document_attention: - sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + sequence_lengths = [sample.sequence_lengths for sample in batch] return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), + token_ids=torch.stack([sample.token_ids for sample in batch]), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, @@ -67,7 +54,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] - _tokenizer: Tokenizer | None _is_setup: bool = False def __init__( @@ -108,7 +94,6 @@ def setup( ) log_main_rank(f"Preparing dataset. 
This may take several minutes.") - self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) if self._cache_directory is None: # TODO: Avoid this @@ -116,11 +101,6 @@ def setup( self._datasets = {} for dataset_name, sampling_parameters in self._sampling_parameters.items(): - if self._tokenizer is not None: - # NOTE: Some models like Qwen2-1.5B-Instruct - # have vocab_size bigger in model config than in tokenizer - # TODO: Still, is it too constraining? - Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size) if sampling_parameters.num_samples > 0: sampling = GPTSamplingData( config=self._config.sampling, @@ -128,7 +108,6 @@ def setup( cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, - tokenizer=self._tokenizer, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -136,21 +115,16 @@ def setup( safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True - @property - def tokenizer(self) -> Tokenizer: - assert self._is_setup - return self._tokenizer - def get_iterator( self, - batch_config: BatchConfig, + batch_config: GPTBatchConfig, dataset_name: str, *, consumed_samples: int, num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[GPTBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index b470c0159..33942708b 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,11 +1,13 @@ import abc import typing +from fast_llm.data.sample.abstract import Sample + if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData -class Dataset(abc.ABC): +class Dataset[SampleType: Sample](abc.ABC): """ A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ @@ -17,15 +19,23 @@ def name(self) -> str: A name for the dataset to facilitate identification and debugging. """ + def __getstate__(self): + state = super().__getstate__() + # Pickling sometimes fails with bound `SampleType`. + # This is not needed at runtime, so we just drop it. + if "__orig_class__" in state: + del state["__orig_class__"] + return state + -class SampledDataset(Dataset): +class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) 
""" @abc.abstractmethod - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: pass @abc.abstractmethod @@ -33,8 +43,8 @@ def __len__(self) -> int: pass -class SamplableDataset(Dataset): +class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod - def sample(self, config: "SamplingData") -> SampledDataset: + def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 24b0fa76f..264eb373d 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,16 +1,16 @@ import logging -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities logger = logging.getLogger(__name__) -class BlendedDataset(SampledDataset): +class BlendedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -21,7 +21,7 @@ class BlendedDataset(SampledDataset): def __init__( self, name: str, - datasets: list[SampledDataset], + datasets: list[SampledDataset[SampleType]], weights: list[float], sampling_config: SamplingData, ): @@ -29,51 +29,52 @@ def __init__( assert len(datasets) > 0 Assert.eq(len(datasets), len(weights)) self._datasets = datasets - self._weights = np.array(normalize_probabilities(weights)) + self._weights = torch.from_numpy(normalize_probabilities(weights, return_array=True)) self._num_samples = sampling_config.parameters.num_samples def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Blending is typically done in one of the following iterative way (ex. in Megatron datasets): ```python dataset_index=np.zeros(num_samples) sample_index=np.zeros(num_samples) sampled=np.zeros(len(weights)) - for idx in range(num_samples): - error = weights * (idx + 1) - sampled + for index in range(num_samples): + error = weights * (index + 1) - sampled dataset_index_ = np.argmax(error) - dataset_index[idx] = dataset_index_ - sample_index[idx] = sampled[dataset_index_] + dataset_index[index] = dataset_index_ + sample_index[index] = sampled[dataset_index_] sampled[dataset_index_] +=1 ``` I.e. it iteratively picks samples to minimize the error `weights * sum(sampled) - sampled`. This implementation computes values on the fly instead of pre-computing them all. """ # We find the number of samples taken from each dataset prior to this point. - sampled = self._get_sampled(idx) + sampled = self._get_sampled(index) # Then get the present sample. - dataset_index = self._get_next_dataset(idx, sampled) - return self._datasets[dataset_index][sampled[dataset_index]] + dataset_index = self._get_next_dataset(index, sampled) + return self._datasets[dataset_index][sampled[dataset_index].item()] - def _get_sampled(self, num_samples: int): + def _get_sampled(self, num_samples: int) -> torch.Tensor: # First we determine a lower bound. 
# This is indeed a lower bound because a lower value for one dataset would involve more sampling below, # and it would be from that same dataset because it would have the highest error, - sampled = np.floor(self._weights * num_samples).astype(int) + + sampled = (self._weights * num_samples).to(torch.int64) # Then we sample until we reach the target number of samples. # This may not match the actual sampling order, but the final value of `sampled` is correct. - for idx in range(sampled.sum(), num_samples): - dataset_index = self._get_next_dataset(idx, sampled) + for index in range(sampled.sum().item(), num_samples): + dataset_index = self._get_next_dataset(index, sampled) sampled[dataset_index] += 1 return sampled - def _get_next_dataset(self, idx, sampled): + def _get_next_dataset(self, index: int, sampled: torch.Tensor) -> int: # The next sample is the one with the highest error. - return (self._weights * (idx + 1) - sampled).argmax() + return (self._weights * (index + 1) - sampled).argmax().item() @property - def name(self): + def name(self) -> str: return self._name diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 0c1b0cd09..7a8d3567d 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -7,6 +7,7 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -64,37 +65,38 @@ def get_next_rank(self) -> int: @config_class() -class DatasetConfig(Config): +class DatasetConfig[SampleType: Sample](Config): _abstract: typing.ClassVar[bool] = True -@config_class() -class SampledDatasetConfig(DatasetConfig): +@config_class(registry=True) +class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + # TODO: ====== `SamplingData` contains more than needed (ex. `num_samples`) raise NotImplementedError() @config_class() -class SamplableDatasetConfig(SampledDatasetConfig): - def build(self) -> SamplableDataset: +class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): + def build(self) -> SamplableDataset[SampleType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: return self.build().sample(sampling) @config_class() -class IndexedDatasetConfig(SamplableDatasetConfig): - def _build(self) -> "IndexedDataset": +class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): + def build(self) -> "IndexedDataset[SampleType]": raise NotImplementedError() -@config_class() -class ConcatenatedDatasetConfig(SamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "concatenated"}) +class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Concatenate multiple indexed datasets as if they were one. TODO: Make a post-sampling version? 
(staged training) @@ -106,7 +108,7 @@ class ConcatenatedDatasetConfig(SamplableDatasetConfig): desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[IndexedDatasetConfig] = Field( + datasets: list[IndexedDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to concatenate.", hint=FieldHint.core, @@ -122,8 +124,8 @@ def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: return cls(self.name, [dataset.build() for dataset in self.datasets]) -@config_class() -class DatasetSliceConfig(SamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "slice"}) +class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Use a fraction of an indexed dataset, specified by the range (begin, end). Typically used to subsample a dataset, or to reserve part of the dataset for validation and/or testing. @@ -133,7 +135,7 @@ class DatasetSliceConfig(SamplableDatasetConfig): """ _abstract = False - dataset: IndexedDatasetConfig = Field( + dataset: IndexedDatasetConfig[SampleType] = Field( default=None, desc="The dataset to split.", hint=FieldHint.core, @@ -152,12 +154,9 @@ class DatasetSliceConfig(SamplableDatasetConfig): def build(self) -> "DatasetSlice": from fast_llm.data.dataset.indexed import DatasetSlice - return self._build(DatasetSlice) - - def _build[T: DatasetSlice](self, cls: type[T]) -> T: dataset = self.dataset.build() size = len(dataset) - return cls( + return DatasetSlice[SampleType]( f"{dataset.name}_{self.begin}_{self.end}", dataset, round(self.begin * size), @@ -165,8 +164,8 @@ def _build[T: DatasetSlice](self, cls: type[T]) -> T: ) -@config_class() -class SampledDatasetUpdateConfig(SampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "sampled"}) +class SampledDatasetUpdateConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): """ Wrap a dataset to explicitly sample from it and optionally update its configuration parameters. Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. @@ -177,24 +176,24 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( desc="The dataset to sample from.", hint=FieldHint.core, ) - def build_and_sample(self, data: SamplingData) -> SampledDataset: + def build_and_sample(self, data: SamplingData) -> SampledDataset[SampleType]: return self.dataset.build_and_sample(data.update_config(self.sampling)) -@config_class() -class BlendedDatasetConfig(SampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "blended"}) +class BlendedDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): _abstract = False name: str = Field( default="blended", desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[SampledDatasetConfig] = Field( + datasets: list[SampledDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to blend.", hint=FieldHint.core, @@ -214,7 +213,7 @@ def _validate(self) -> None: def build_and_sample( self, sampling: SamplingData, - ) -> SampledDataset: + ) -> SampledDataset[SampleType]: from fast_llm.data.dataset.blended import BlendedDataset # Build and sample the datasets. @@ -235,7 +234,7 @@ def build_and_sample( for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) ] # Blend the datasets. 
- return BlendedDataset( + return BlendedDataset[SampleType]( self.name, sampled_datasets, self.weights, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 656cd7d24..36412b6ce 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -6,27 +6,23 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import ( - BlendedDatasetConfig, - ConcatenatedDatasetConfig, - DatasetSliceConfig, IndexedDatasetConfig, SamplableDatasetConfig, SampledDatasetConfig, - SampledDatasetUpdateConfig, SamplingConfig, SamplingData, SamplingParameters, ) +from fast_llm.data.sample.gpt import GPTSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset - from fast_llm.data.tokenizer import Tokenizer class ShufflingType(str, enum.Enum): @@ -86,27 +82,10 @@ class GPTSamplingData(SamplingData): config: GPTSamplingConfig parameters: GPTSamplingParameters - tokenizer: "Tokenizer" -@config_class(registry=True) -class GPTSampledDatasetConfig(SampledDatasetConfig): - pass - - -@config_class() -class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig): - pass - - -@config_class() -class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig, IndexedDatasetConfig): - def build(self) -> "GPTIndexedDataset": - raise NotImplementedError() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "random"}) +class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -120,8 +99,8 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class GPTMemmapDatasetConfig[SampleType: GPTSample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -145,43 +124,8 @@ def build(self) -> "GPTMemmapDataset": return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) -@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) -class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() - - def build(self) -> "GPTConcatenatedDataset": - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset - - return self._build(GPTConcatenatedDataset) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) -class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - dataset: GPTIndexedDatasetConfig = FieldUpdate() - - def build(self) -> 
"GPTDatasetSlice": - from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice - - return self._build(GPTDatasetSlice) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) -class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): - _abstract = False - sampling: GPTSamplingConfig = FieldUpdate() - dataset: GPTSampledDatasetConfig = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) -class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "file"}) +class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -189,18 +133,18 @@ class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: config = self._load_config() return config.build_and_sample(sampling) - def build(self) -> SamplableDataset: + def build(self) -> SamplableDataset[SampleType]: config = self._load_config() - assert isinstance(config, GPTSamplableDatasetConfig) + assert isinstance(config, SamplableDatasetConfig) return config.build() - def _load_config(self): + def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." - return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. @@ -224,6 +168,10 @@ class FimConfig(Config): Configuration for FIM. """ + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + hint=FieldHint.feature, + ) rate: float = Field( # TODO: Use meaningful default now that fim is a wrapper? default=0.0, @@ -286,15 +234,15 @@ class FimConfig(Config): ) -@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): +@config_class(dynamic_type={SampledDatasetConfig: "fim"}) +class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig): """ Configuration for FIM. 
""" _abstract: typing.ClassVar[bool] = False - dataset: GPTSampledDatasetConfig = Field( + dataset: SampledDatasetConfig = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -302,15 +250,15 @@ class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): def build_and_sample( self, - sampling: GPTSamplingData, + sampling: SamplingData, ) -> SampledDataset: from fast_llm.data.dataset.gpt.fim import GPTFimDataset return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) -@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) +class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. """ @@ -323,8 +271,8 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: assert sampling.distributed.config.world_size > 1 if sampling.distributed.config.rank == 0: time.sleep(self.sleep) - return GPTRandomDatasetConfig().build_and_sample(sampling) + return GPTRandomDatasetConfig[SampleType]().build_and_sample(sampling) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..175a0e549 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -1,12 +1,13 @@ import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset(SampledDataset): +class GPTFimDataset[SampleType: GPTSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -15,7 +16,7 @@ class GPTFimDataset(SampledDataset): def __init__( self, config: FimConfig, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], sampling: GPTSamplingData, ): if sampling.parameters.use_loss_masking_spans: @@ -26,7 +27,7 @@ def __init__( self._dataset = dataset self._seed = sampling.config.seed - self._tokenizer = sampling.tokenizer + self._tokenizer = self._config.tokenizer.get_tokenizer() if self._tokenizer is None: raise ValueError("Fim requires a tokenizer") self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = ( @@ -40,11 +41,15 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx: int) -> np.ndarray: - fim_token_ids = self._fim( - self._dataset[idx].token_ids, np.random.RandomState(seed=(self._seed + idx) % MAX_SEED) + def __getitem__(self, index: int) -> SampleType: + # TODO: Use torch methods to avoid back and forth. 
+ return GPTSample( + torch.from_numpy( + self._fim( + self._dataset[index].token_ids.numpy(), np.random.RandomState(seed=(self._seed + index) % MAX_SEED) + ) + ) ) - return GPTSample(fim_token_ids) @property def name(self) -> str: @@ -55,6 +60,7 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # TODO: permute segments in sample_list, before concatenating. sample_len = sample.shape[0] eod = self._tokenizer.eod + # TODO: Available through `tokens.lengths` segment_breaks = np.argwhere(sample == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py deleted file mode 100644 index 896229772..000000000 --- a/fast_llm/data/dataset/gpt/indexed.py +++ /dev/null @@ -1,60 +0,0 @@ -import abc -import typing - -import numpy as np - -from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - -if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - -class GPTIndexedDataset(IndexedDataset): - @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - @abc.abstractmethod - def get_document_size(self, index: int) -> int: - """ - The size of a document in the dataset. - """ - - def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - return GPTSampledIndexedDataset(self, sampling) - - -class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): - """ - A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. - """ - - _dataset: GPTIndexedDataset - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] - - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) - - -class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( - ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset -): - _datasets: list[GPTIndexedDataset] - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get_document_size(self, index: int) -> int: - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f4..c78805380 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -3,15 +3,17 @@ import typing import numpy as np +import torch -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER +from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -class GPTMemmapDataset(GPTIndexedDataset): +class GPTMemmapDataset[SampleType: GPTSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -142,37 +144,34 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - def get( - self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, - ) -> GPTSample: + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + ) -> SampleType: + if end is None: + end = self.get_document_size(index) token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + count=end - begin, + offset=self._pointers[index] + begin * np.dtype(self._dtype).itemsize, ) sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] + if parameters is not None and parameters.use_loss_masking_spans: + assert self._spans is not None + sample_spans = self._spans[index] # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] + sample_spans = sample_spans[(sample_spans[:, 0] < begin + len(token_ids)) & (sample_spans[:, 1] >= begin)] # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], begin) - begin # offset + sample_spans[:, 1] = np.minimum(sample_spans[:, 1], begin + len(token_ids) - 1) - begin + sample_spans = torch.from_numpy(sample_spans) chosen_span = None rejected_span = None - if use_preference_loss_spans: + if parameters is not None and parameters.use_preference_loss_spans: if not self._has_preference_spans: raise ValueError("No preference spans found in memmap dataset.") elif self._has_preference_spans and self._chosen_spans is None: @@ -180,28 +179,30 @@ def get( elif self._has_preference_spans and self._rejected_spans is None: 
raise ValueError("Failed to read rejected spans from memmap dataset.") else: - chosen_span = self._chosen_spans[idx] + chosen_span = self._chosen_spans[index] # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + chosen_span = chosen_span[(chosen_span[0] < begin + len(token_ids)) & (chosen_span[1] >= begin)][0] # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + chosen_span[0] = np.maximum(chosen_span[0], begin) - begin # offset + chosen_span[1] = np.minimum(chosen_span[1], begin + len(token_ids) - 1) - begin + chosen_span = torch.from_numpy(chosen_span) - rejected_span = self._rejected_spans[idx] + rejected_span = self._rejected_spans[index] # filter spans that are outside the range of the selected tokens in the document rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) + (rejected_span[0] < begin + len(token_ids)) & (rejected_span[1] >= begin) ][0] # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + rejected_span[0] = np.maximum(rejected_span[0], begin) - begin # offset + rejected_span[1] = np.minimum(rejected_span[1], begin + len(token_ids) - 1) - begin + rejected_span = torch.from_numpy(rejected_span) return GPTSample( - token_ids=token_ids, + token_ids=torch.from_numpy(token_ids), loss_masking_spans=sample_spans, chosen_span=chosen_span, rejected_span=rejected_span, @@ -218,13 +219,13 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self) -> torch.Tensor: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return torch.from_numpy(self._document_sizes) def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() @@ -258,7 +259,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." 
# Write document to binary file - bin_stream.write(document.token_ids.tobytes(order="C")) + bin_stream.write(document.token_ids.numpy().tobytes(order="C")) # Update metadata doc_length = len(document.token_ids) @@ -271,7 +272,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP chosen_spans.append(document.chosen_span) if document.rejected_span is not None: rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * dtype.itemsize num_documents += 1 # Finalize metadata arrays @@ -297,7 +298,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Flag to indicate whether preference loss-masking spans are present idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) # Data type - idx_stream.write(struct.pack(" str: return self._name -class GPTRandomSampledDataset(SampledDataset): +class GPTRandomSampledDataset[SampleType: GPTSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed @@ -32,10 +33,12 @@ def __init__(self, sampling: GPTSamplingData, name: str): def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx) -> np.ndarray: + def __getitem__(self, index: int) -> SampleType: return GPTSample( - np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + ) ) ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 09ed52779..c6eac9e28 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -1,20 +1,37 @@ import abc -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.data.dataset.config import SamplingData, SamplingParameters +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, padded_cumsum -class IndexedDataset(SamplableDataset): +class IndexedDataset[SampleType: Sample](SamplableDataset[SampleType]): """ A dataset containing a list of samples. TODO: Move sampling responsibility here? """ @abc.abstractmethod - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + """ + The size of a document in the dataset. + """ + + @abc.abstractmethod + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: pass @abc.abstractmethod @@ -23,13 +40,18 @@ def __len__(self) -> int: Number of samples in the dataset. 
""" + def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": + from fast_llm.data.dataset.sampled import SampledIndexedDataset + + return SampledIndexedDataset(self, sampling) -class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset): + +class DatasetSlice[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - dataset: IndexedDataset, + dataset: IndexedDataset[SampleType], begin: int | None = None, end: int | None = None, ): @@ -46,15 +68,22 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def get( - self, document: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] + + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: """ Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length + optionally subsampled to a specific offset (starting point) and maximum length (end = min(offset + length, sample_length). """ - return self._dataset.get(document + self._begin, offset, length, use_loss_masking_spans) + return self._dataset.get_document(index + self._begin, begin, end, parameters) def __len__(self) -> int: return self._end - self._begin @@ -64,24 +93,36 @@ def name(self) -> str: return self._name -class ConcatenatedDataset[IndexedDatasetType: IndexedDataset](IndexedDataset): +class ConcatenatedDataset[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - datasets: list[IndexedDataset], + datasets: list[IndexedDataset[SampleType]], ): self._name = name self._datasets = datasets sizes = [len(dataset) for dataset in self._datasets] - self._dataset_splits = padded_cumsum(sizes) + self._dataset_splits = torch.from_numpy(padded_cumsum(sizes)) def __len__(self) -> int: return self._dataset_splits[-1].item() - def get(self, index: int, *args, **kwargs): - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get(index - self._dataset_splits[dataset].item(), *args, **kwargs) + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. 
+ return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) + + def get_document_size(self, index: int) -> int: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document( + index - self._dataset_splits[dataset].item(), begin, end, parameters + ) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080fe..01f3195e4 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -1,8 +1,8 @@ import logging import time -import typing from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.sample.abstract import Sample try: from fast_llm.csrc.data import build_blending_indices # noqa @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class DatasetMonitor(SampledDataset): +class DatasetMonitor[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -24,7 +24,7 @@ class DatasetMonitor(SampledDataset): def __init__( self, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], data_sample_warn_time_ms: float, ): self._dataset = dataset @@ -33,19 +33,19 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: start_time = time.perf_counter() try: - sample = self._dataset[idx] + sample = self._dataset[index] sample_time = (time.perf_counter() - start_time) * 1000 if sample_time > self._data_sample_warn_time_ms: logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" + f"Sample {index} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" ) return sample except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + logger.error(f"Failed to get sample {index} from dataset {self._dataset.name}") raise @property diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/sampled.py similarity index 93% rename from fast_llm/data/dataset/gpt/sampled.py rename to fast_llm/data/dataset/sampled.py index 95006f18e..238e99bca 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -1,4 +1,3 @@ -import dataclasses import logging import math import pathlib @@ -11,7 +10,9 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.abstract import Sample +from fast_llm.data.sample.gpt import GPTSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -26,15 +27,6 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTSample: - token_ids: np.ndarray - 
loss_masking_spans: np.ndarray | None = None - chosen_span: np.ndarray | None = None - rejected_span: np.ndarray | None = None - sequence_lengths: np.ndarray | None = None - - class MemmapArray: """ An array with lazy loading in memmap mode. @@ -75,14 +67,15 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class GPTSampledIndexedDataset(SampledDataset): +class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A sampled GPT dataset. """ def __init__( self, - indexed_dataset: GPTIndexedDataset, + indexed_dataset: IndexedDataset[SampleType], + # TODO: ====== Remove gpt-specific stuff ====== sampling: GPTSamplingData, ): assert isinstance(sampling, GPTSamplingData) @@ -133,7 +126,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes = self._indexed_dataset.get_document_sizes().to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() @@ -375,7 +368,7 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - def __len__(self) -> int: return self._parameters.num_samples - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. @@ -391,12 +384,11 @@ def __getitem__(self, index: int) -> typing.Any: self._document_shuffling[index - self._unshuffled_documents].item() ] - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, + sample = self._indexed_dataset.get_document( + document_index.item(), + begin=0, + end=self._document_sizes[document_index].item(), + parameters=self._parameters, ) chosen_span_end = sample.chosen_span[1] + 1 @@ -412,7 +404,7 @@ def __getitem__(self, index: int) -> typing.Any: sample.token_ids = padding if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) + sample.sequence_lengths = torch.tensor(sequence_lengths) return sample @@ -474,11 +466,11 @@ def __getitem__(self, index: int) -> typing.Any: # Determine which part of the document belong to the sample, and add it to the list. 
token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( + sample = self._indexed_dataset.get_document( document_index, - offset=token_start_index_in_document, - length=token_end_index_in_document - token_start_index_in_document, - use_loss_masking_spans=self._parameters.use_loss_masking_spans, + begin=token_start_index_in_document, + end=token_end_index_in_document, + parameters=self._parameters, ) token_ids.append(sample.token_ids) if self._parameters.use_loss_masking_spans: @@ -496,19 +488,23 @@ def __getitem__(self, index: int) -> typing.Any: token_count += document_size sequence_lengths = ( - np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + torch.tensor([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) if not self._parameters.cross_document_attention else None ) token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( - (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) + torch.from_numpy(np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) if self._parameters.use_loss_masking_spans else None ) Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=torch.from_numpy(token_ids), + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + ) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 33c40bf8f..a8ff187ae 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -14,17 +14,17 @@ import transformers import yaml -from fast_llm.data.dataset.gpt.config import ( - GPTBlendedDatasetConfig, - GPTDatasetSliceConfig, - GPTIndexedDatasetConfig, - GPTMemmapDatasetConfig, - GPTSampledDatasetConfig, +from fast_llm.data.dataset.config import ( + BlendedDatasetConfig, + DatasetSliceConfig, + IndexedDatasetConfig, + SampledDatasetConfig, ) +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.sample.gpt import GPTSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,6 +37,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None + _sample_type: typing.ClassVar[type[GPTSample]] = GPTSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ @@ -144,8 +145,8 @@ def _document_generator(): if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( - 
np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), ) elif ( "chosen_token_spans" in shard_dataset.column_names @@ -155,13 +156,13 @@ def _document_generator(): ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), + token_ids=torch.tensor(item["input_ids"], dtype=self._data_type.torch), + chosen_span=torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), + rejected_span=torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + yield GPTSample(torch.tensor(item["input_ids"], dtype=self._data_type.torch)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -376,7 +377,9 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa torch.distributed.destroy_process_group() @classmethod - def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: + def _save_dataset_config( + cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path + ) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( dataset_config.to_dict(), @@ -384,10 +387,12 @@ def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_pa ) @classmethod - def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) -> GPTIndexedDatasetConfig: + def _blend_dataset_configs( + cls, dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]] + ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] - return GPTSampledDatasetConfig.from_dict( + return SampledDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": dataset_configs, @@ -397,8 +402,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path - ) -> dict[str, GPTSampledDatasetConfig]: + cls, + dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]], + splits: dict[str, int | float], + output_path: pathlib.Path, + ) -> dict[str, SampledDatasetConfig[_sample_type]]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) @@ -427,13 +435,13 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). 
                dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build()
-                sizes_cumsum = dataset.get_document_sizes().cumsum()
+                sizes_cumsum = dataset.get_document_sizes().numpy().cumsum()
                Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens)
                begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens)
                end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens)
                if end_index > begin_index:
                    datasets_in_split.append(
-                        GPTDatasetSliceConfig.from_dict(
+                        DatasetSliceConfig[cls._sample_type].from_dict(
                            {
                                "type": "slice",
                                "dataset": dataset_configs[dataset_index],
@@ -455,7 +463,7 @@ def _split_and_blend_dataset_configs(
            elif len(datasets_in_split) == 1:
                dataset_splits[split_name] = datasets_in_split[0]
            else:
-                dataset_splits[split_name] = GPTBlendedDatasetConfig.from_dict(
+                dataset_splits[split_name] = BlendedDatasetConfig[cls._sample_type].from_dict(
                    {
                        "type": "blended",
                        "datasets": datasets_in_split,
diff --git a/fast_llm/data/sample/__init__.py b/fast_llm/data/sample/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py
new file mode 100644
index 000000000..0c640b9b3
--- /dev/null
+++ b/fast_llm/data/sample/abstract.py
@@ -0,0 +1,10 @@
+import abc
+
+
+class Sample(abc.ABC):
+    pass
+
+
+class Batch(abc.ABC):
+    # TODO: Relate to `BatchConfig`?
+    pass
diff --git a/fast_llm/data/sample/gpt.py b/fast_llm/data/sample/gpt.py
new file mode 100644
index 000000000..4bf740462
--- /dev/null
+++ b/fast_llm/data/sample/gpt.py
@@ -0,0 +1,25 @@
+import dataclasses
+import typing
+
+from fast_llm.data.sample.abstract import Batch, Sample
+
+if typing.TYPE_CHECKING:
+    import torch
+
+
+@dataclasses.dataclass
+class GPTSample(Sample):
+    token_ids: "torch.Tensor"
+    loss_masking_spans: "torch.Tensor | None" = None
+    chosen_span: "torch.Tensor | None" = None
+    rejected_span: "torch.Tensor | None" = None
+    sequence_lengths: "torch.Tensor | None" = None
+
+
+@dataclasses.dataclass
+class GPTBatch(Batch):
+    token_ids: "torch.Tensor"
+    loss_masking_spans: "list[torch.Tensor] | None" = None
+    sequence_lengths: "list[torch.Tensor] | None" = None
+    chosen_spans: "list[torch.Tensor] | None" = None
+    rejected_spans: "list[torch.Tensor] | None" = None
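The new sample module keeps `GPTSample` and `GPTBatch` as plain dataclasses. A small, self-contained sketch of how they are put together (the tensor values are made up for illustration):

    import torch

    from fast_llm.data.sample.gpt import GPTBatch, GPTSample

    # One document: token ids plus an optional (begin, end) loss-masking span.
    sample = GPTSample(
        token_ids=torch.tensor([5, 8, 13, 21, 34], dtype=torch.int64),
        loss_masking_spans=torch.tensor([[1, 3]], dtype=torch.int32),
    )
    # A batch stacks the token ids and keeps the optional per-sample fields as lists.
    batch = GPTBatch(
        token_ids=sample.token_ids.unsqueeze(0),
        loss_masking_spans=[sample.loss_masking_spans],
    )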
diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py
index f4a2cfd6c..add121c50 100644
--- a/fast_llm/engine/config_utils/data_type.py
+++ b/fast_llm/engine/config_utils/data_type.py
@@ -23,8 +23,10 @@ class DataType(enum.StrEnum):
     int32 = "int32"
     int16 = "int16"
     int8 = "int8"
-    uint8 = "uint8"
+    uint64 = "uint64"
+    uint32 = "uint32"
     uint16 = "uint16"
+    uint8 = "uint8"
 
     @classmethod
     def _missing_(cls, dtype: str) -> "DataType":
@@ -105,6 +107,9 @@ def _set_torch_dtype_map() -> None:
         DataType.int32: torch.int32,
         DataType.int16: torch.int16,
         DataType.int8: torch.int8,
+        DataType.uint64: torch.uint64,
+        DataType.uint32: torch.uint32,
+        DataType.uint16: torch.uint16,
         DataType.uint8: torch.uint8,
     }
     _TORCH_DTYPE_MAP_INV = {y: x for x, y in _TORCH_DTYPE_MAP.items()}
@@ -127,8 +132,10 @@ def _set_numpy_dtype_map() -> None:
         DataType.int32: np.int32,
         DataType.int16: np.int16,
         DataType.int8: np.int8,
-        DataType.uint8: np.uint8,
+        DataType.uint64: np.uint64,
+        DataType.uint32: np.uint32,
         DataType.uint16: np.uint16,
+        DataType.uint8: np.uint8,
     }
     _NUMPY_DTYPE_MAP_INV = {y: x for x, y in _NUMPY_DTYPE_MAP.items()}
@@ -151,6 +158,9 @@ def _set_triton_dtype_map() -> None:
         DataType.int32: tl.int32,
         DataType.int16: tl.int16,
         DataType.int8: tl.int8,
+        DataType.uint64: tl.uint64,
+        DataType.uint32: tl.uint32,
+        DataType.uint16: tl.uint16,
         DataType.uint8: tl.uint8,
     }
diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py
index 4f035e174..f8dfd4825 100644
--- a/fast_llm/engine/evaluation/config.py
+++ b/fast_llm/engine/evaluation/config.py
@@ -2,6 +2,7 @@
 import typing
 
 from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
+from fast_llm.data.config import TokenizerConfig
 from fast_llm.engine.schedule.config import BatchConfig
 from fast_llm.utils import Assert
 
@@ -63,6 +64,9 @@ class LmEvalEvaluatorConfig(EvaluatorConfig):
     _abstract: typing.ClassVar[bool] = False
 
+    tokenizer: TokenizerConfig = Field(
+        desc="Configuration for the tokenizer.",
+    )
     cli_args: list[str] = Field(
         default_factory=lambda: [],
         desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.",
diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py
index 14aed65c4..5bfb544ed 100644
--- a/fast_llm/engine/evaluation/lm_eval/evaluator.py
+++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py
@@ -60,7 +60,7 @@ def setup(
 
         self._flm_wrapper = FastLLMLmEvalWrapper(
             model=self._hf_model,
-            tokenizer=self._data.tokenizer.tokenizer,
+            tokenizer=self._config.tokenizer.get_tokenizer(),
             truncation=self._config.truncation,
             logits_cache=self._config.logits_cache,
             add_bos_token=self._config.add_bos_token,
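Since the tokenizer is now configured on the evaluator itself, it can be built directly from that config rather than pulled out of the data object. A sketch of the intended usage, assuming a local tokenizer file (the path is hypothetical):

    from fast_llm.data.config import TokenizerConfig

    tokenizer_config = TokenizerConfig.from_dict({"path": "/path/to/tokenizer.json"})  # hypothetical path
    tokenizer = tokenizer_config.get_tokenizer()  # returns a fast_llm.data.tokenizer.Tokenizer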
diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py
index 9215e6dc7..a76c3712e 100644
--- a/fast_llm/models/gpt/huggingface.py
+++ b/fast_llm/models/gpt/huggingface.py
@@ -5,7 +5,7 @@
 import torch
 import transformers.modeling_outputs
 
-from fast_llm.data.data.gpt.data import GPTBatch
+from fast_llm.data.sample.gpt import GPTBatch
 from fast_llm.engine.distributed.config import PhaseType
 from fast_llm.engine.inference.config import HuggingfaceModelConfig
 from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index efa348ecb..bd3c91a38 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from fast_llm.data.data.gpt.data import GPTBatch
+from fast_llm.data.sample.gpt import GPTBatch
 from fast_llm.engine.base_model.base_model import BaseModel
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType
diff --git a/tests/data/common.py b/tests/data/common.py
index d8cc6fff2..3ade0e9bf 100644
--- a/tests/data/common.py
+++ b/tests/data/common.py
@@ -8,17 +8,11 @@
 from fast_llm.data.data.gpt.config import GPTDataConfig
 from fast_llm.data.data.gpt.data import GPTData
 from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.gpt.config import (
-    GPTIndexedDatasetConfig,
-    GPTSampledDatasetConfig,
-    GPTSamplingConfig,
-    GPTSamplingData,
-    GPTSamplingParameters,
-    ShufflingType,
-)
-from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset
-from fast_llm.data.tokenizer import Tokenizer
+from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig, SamplingParameters
+from fast_llm.data.dataset.gpt.config import GPTSamplingConfig, GPTSamplingData, GPTSamplingParameters, ShufflingType
+from fast_llm.data.dataset.indexed import IndexedDataset
+from fast_llm.data.dataset.sampled import SampledIndexedDataset
+from fast_llm.data.sample.abstract import Sample
 from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.models.gpt.config import GPTBatchConfig
@@ -34,7 +28,6 @@ def get_sampling_data(
     phase=PhaseType.training,
     sequence_length: int = 512,
     vocab_size=TEST_VOCAB_SIZE,
-    tokenizer: Tokenizer | None = None,
     gpu: bool = False,
     shuffle: ShufflingType = ShufflingType.epoch,
     truncate_documents=True,
@@ -56,13 +49,12 @@ def get_sampling_data(
         cache_directory=cache_directory,
         distributed=distributed,
         dataset_name=phase.value,
-        tokenizer=tokenizer,
     )
 
 
-def get_dataset_config[T: GPTSampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T:
-    dataset_config = GPTSampledDatasetConfig.from_dict(config)
-    Assert.custom(isinstance, dataset_config, cls)
+def get_dataset_config[T: SampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T:
+    dataset_config = SampledDatasetConfig.from_dict(config)
+    Assert.custom(isinstance, dataset_config, getattr(cls, "__origin__", cls))
     return typing.cast(cls, dataset_config)
 
 
@@ -115,7 +107,7 @@ def get_test_data_and_compare_samples(
 
 
 def compare_indexed_dataset(
-    dataset: GPTIndexedDataset,
+    dataset: IndexedDataset,
     length: int,
     num_tokens: int,
     expected_samples: dict[int, list[int]],
@@ -125,26 +117,30 @@
     sizes = dataset.get_document_sizes()
     # Assert.eq(sizes.sum(), num_tokens)
     Assert.all_equal(
-        [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)]
+        [len(dataset.get_document(i).token_ids) for i in range(min(len(dataset), 100))],
+        sizes[: min(len(dataset), 100)],
     )
     for i, expected_sample in expected_samples.items():
-        Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16))
+        Assert.all_equal(dataset.get_document(i).token_ids, np.array(expected_sample, dtype=np.uint16))
     if loss_masking_spans:
         for i, loss_masking_span in loss_masking_spans.items():
             Assert.all_equal(
-                dataset.get(i, use_loss_masking_spans=True).loss_masking_spans,
+                dataset.get_document(
+                    i,
+                    parameters=GPTSamplingParameters(
+                        num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True
+                    ),
+                ).loss_masking_spans,
                 np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2),
             )
 
 
 def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None:
     Assert.eq(len(sampled), len(expected_samples))
-    Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples)
+    Assert.all_equal(torch.stack([sampled[i].token_ids for i in range(len(expected_samples))]), expected_samples)
 
 
-def validate_indexed_dataset_sampling(
-    sampled: GPTSampledIndexedDataset, expected_samples: list[list[int]] | None = None
-):
+def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None):
     """
     Compare `GPTSampledIndexedDataset` sampling against a more basic approach
     """
@@ -165,7 +161,7 @@ def validate_indexed_dataset_sampling(
     )
     seen_tokens = 0
     for document_index in document_sampling:
-        document = sampled._indexed_dataset.get(document_index).token_ids
+        document = sampled._indexed_dataset.get_document(document_index).token_ids
         all_tokens[seen_tokens : seen_tokens + len(document)] = document[: num_tokens - seen_tokens]
         seen_tokens += len(document)
@@ -176,7 +172,7 @@
         all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1]
         for index in range(sampled._parameters.num_samples)
     ]
-    token_ids = [sampled[i].token_ids for i in range(len(sampled))]
+    token_ids = torch.stack([sampled[i].token_ids for i in range(len(sampled))])
     Assert.all_equal(token_ids, validate_samples)
 
     if expected_samples is not None:
@@ -184,8 +180,8 @@
     return token_ids
 
 
-@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"})
-class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
+@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"})
+class MockGPTMemmapDatasetConfig(IndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
     num_documents: int | None = Field(
         default=None,
@@ -199,15 +195,15 @@ class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
     )
     path: pathlib.Path = Field(default=".")
 
-    def build(self) -> "GPTIndexedDataset":
-        return MockGPTMemmapDataset(self)
+    def build(self) -> "IndexedDataset":
+        return MockMemmapDataset(self)
 
     @property
     def num_tokens(self) -> int:
         return self.num_documents * self.num_tokens_per_document
 
 
-class MockGPTMemmapDataset(GPTIndexedDataset):
+class MockMemmapDataset[SampleType: Sample](IndexedDataset[SampleType]):
     def __init__(self, config: MockGPTMemmapDatasetConfig):
         self._config = config
@@ -218,11 +214,13 @@ def name(self) -> str:
 
     def __len__(self) -> int:
         return self._config.num_documents
 
-    def get_document_sizes(self) -> np.ndarray:
-        return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64)
+    def get_document_sizes(self) -> torch.Tensor:
+        return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64)
 
     def get_document_size(self, index: int) -> int:
         return self._config.num_tokens_per_document
 
-    def get(self, index: int, *args, **kwargs) -> typing.Any:
+    def get_document(
+        self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None
+    ) -> SampleType:
         raise NotImplementedError()
diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py
index e64b47020..678bffa21 100644
--- a/tests/data/test_blending.py
+++ b/tests/data/test_blending.py
@@ -3,7 +3,8 @@
 import numpy as np
 import pytest
 
-from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig
+from fast_llm.data.dataset.config import BlendedDatasetConfig
+from fast_llm.data.sample.gpt import GPTSample
 from fast_llm.utils import Assert, normalize_probabilities
 from tests.data.common import (
     compare_sampled_dataset,
     get_dataset_config,
@@ -122,7 +123,7 @@ def test_gpt_blended():
             ],
             "weights": [0.75, 0.25],
         },
-        GPTBlendedDatasetConfig,
+        BlendedDatasetConfig[GPTSample],
     ).build_and_sample(get_sampling_data(8, sequence_length=5))
     compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES)
@@ -161,7 +162,7 @@ def test_gpt_blended_mixed():
             ],
             "weights": [0.6, 0.4],
         },
-        GPTBlendedDatasetConfig,
+        BlendedDatasetConfig[GPTSample],
     ).build_and_sample(get_sampling_data(8, sequence_length=5))
     compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES)
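The GPT-prefixed config classes used by these tests are gone; the generic configs are parametrized by the sample type instead. A sketch of the new spelling, mirroring the blending test above (the paths are placeholders and `sampling_data` is assumed to come from `get_sampling_data`):

    from fast_llm.data.dataset.config import BlendedDatasetConfig
    from fast_llm.data.sample.gpt import GPTSample

    config = BlendedDatasetConfig[GPTSample].from_dict(
        {
            "type": "blended",
            "datasets": [
                {"type": "memmap", "path": "dataset_0"},  # placeholder paths
                {"type": "memmap", "path": "dataset_1"},
            ],
            "weights": [0.75, 0.25],
        }
    )
    sampled = config.build_and_sample(sampling_data)  # `sampling_data` as returned by get_sampling_data(...)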
diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py
index 2c025cbaf..bb4905cb6 100644
--- a/tests/data/test_concatenate.py
+++ b/tests/data/test_concatenate.py
@@ -1,4 +1,5 @@
-from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig
+from fast_llm.data.dataset.config import ConcatenatedDatasetConfig
+from fast_llm.data.sample.gpt import GPTSample
 from tests.data.common import (
     compare_indexed_dataset,
     compare_sampled_dataset,
@@ -27,7 +28,7 @@ def test_gpt_concatenate():
     get_test_dataset()
     dataset = get_dataset_config(
         {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]},
-        GPTConcatenatedDatasetConfig,
+        ConcatenatedDatasetConfig[GPTSample],
     ).build()
     compare_indexed_dataset(
         dataset,
diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py
index c9212d6e3..438c5e7e3 100644
--- a/tests/data/test_fim.py
+++ b/tests/data/test_fim.py
@@ -1,6 +1,4 @@
-from fast_llm.data.config import TokenizerConfig
 from fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig
-from fast_llm.data.tokenizer import Tokenizer
 from tests.data.common import (
     compare_sampled_dataset,
     get_dataset_config,
@@ -29,13 +27,13 @@ def test_gpt_fim():
     sampling_config = get_sampling_data(
         8,
         sequence_length=5,
-        tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})),
         vocab_size=49157,
     )
     sampled = get_dataset_config(
         {
             "type": "fim",
             "dataset": {"type": "memmap", "path": DATASET_PREFIX},
+            "tokenizer": {"path": TOKENIZER_PATH},
             "rate": 0.5,
             "prefix_token": "w",
             "middle_token": "x",
@@ -55,6 +53,7 @@ def test_gpt_fim_data():
                 "training": {
                     "type": "fim",
                     "dataset": {"type": "memmap", "path": DATASET_PREFIX},
+                    "tokenizer": {"path": TOKENIZER_PATH},
                     "rate": 0.5,
                     "prefix_token": "w",
                     "middle_token": "x",
@@ -62,7 +61,6 @@
                     "suffix_token": "z",
                 }
             },
-            "tokenizer": {"path": TOKENIZER_PATH},
         },
         8,
         sequence_length=5,
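With this change the FIM sampled dataset owns its tokenizer configuration instead of borrowing the top-level data tokenizer. An illustrative config dict, mirroring the test above (the paths are placeholders):

    from fast_llm.data.dataset.config import SampledDatasetConfig

    fim_config = SampledDatasetConfig.from_dict(
        {
            "type": "fim",
            "dataset": {"type": "memmap", "path": "/path/to/dataset_prefix"},  # placeholder
            "tokenizer": {"path": "/path/to/tokenizer.json"},  # moved here from the data config
            "rate": 0.5,
            "prefix_token": "w",
            "middle_token": "x",
            "pad_token": "y",
            "suffix_token": "z",
        }
    )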
diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py
index 17ba5de01..388726bfb 100644
--- a/tests/data/test_prepare_gpt_memmap.py
+++ b/tests/data/test_prepare_gpt_memmap.py
@@ -4,12 +4,14 @@
 
 import numpy as np
 import pytest
+import torch
 
-from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig
+from fast_llm.data.dataset.config import IndexedDatasetConfig
+from fast_llm.data.dataset.gpt.config import GPTSamplingParameters
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample
 from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig
 from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
+from fast_llm.data.sample.gpt import GPTSample
 from fast_llm.utils import Assert
 from tests.data.common import MockGPTMemmapDatasetConfig  # Noqa
@@ -28,22 +30,25 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset
 
 @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
 def test_write_memmap_dataset(dtype):
-    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
+    documents = [
+        GPTSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)))
+        for _ in range(100)
+    ]
     with tempfile.TemporaryDirectory() as temp_dir:
         prefix = pathlib.Path(temp_dir)
         GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
         dataset = GPTMemmapDataset(name="foo", prefix=prefix)
         for i, document in enumerate(documents):
             assert np.array_equal(
-                dataset.get(i).token_ids, document.token_ids, equal_nan=True
-            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
+                dataset.get_document(i).token_ids, document.token_ids, equal_nan=True
+            ), f"Mismatch for document {i}: {document} != {dataset.get_document(i)}."
 
 
 @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
 def test_write_memmap_preference_dataset(dtype):
     def generate_valid_span(max_seq_length):
         span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)
-        return np.sort(span)
+        return torch.from_numpy(np.sort(span))
 
     vocab_size = 1000
     max_seq_length = 8192
@@ -51,7 +56,7 @@ def generate_valid_span(max_seq_length):
     documents = [
         GPTSample(
-            token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype),
+            token_ids=torch.from_numpy(np.random.randint(vocab_size, size=max_seq_length).astype(dtype)),
             chosen_span=generate_valid_span(max_seq_length=max_seq_length),
             rejected_span=generate_valid_span(max_seq_length=max_seq_length),
         )
@@ -62,18 +67,23 @@ def generate_valid_span(max_seq_length):
         GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
         dataset = GPTMemmapDataset(name="foo", prefix=prefix)
         for i, document in enumerate(documents):
-            dataset_item = dataset.get(i, use_preference_loss_spans=True)
+            dataset_item = dataset.get_document(
+                i,
+                parameters=GPTSamplingParameters(
+                    num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True
+                ),
+            )
             assert np.array_equal(
                 dataset_item.token_ids, document.token_ids, equal_nan=True
-            ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}."
+            ), f"Token ids mismatch for document {i}: {document} != {dataset.get_document(i)}."
 
             assert np.array_equal(
                 dataset_item.chosen_span, document.chosen_span, equal_nan=True
-            ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}."
+            ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get_document(i).chosen_span}."
 
             assert np.array_equal(
                 dataset_item.rejected_span, document.rejected_span, equal_nan=True
-            ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}."
+            ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get_document(i).rejected_span}."
 
 
 def test_load_metadata_from_hub():
@@ -126,7 +136,7 @@ def test_absent_metadata_local():
 
 
 def test_split_dataset():
-    dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy())
+    dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy())
     config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs(
         [dataset_config_0],
         {"training": 3, "validation": 1},
@@ -154,8 +164,8 @@ def test_split_dataset():
 
 
 def test_split_datasets_0():
-    dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy())
-    dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy())
+    dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy())
+    dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy())
     config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs(
         [dataset_config_0, dataset_config_1],
         {"training": 1, "validation": 1},
@@ -173,8 +183,8 @@ def test_split_datasets_0():
 
 
 def test_split_datasets_1():
-    dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy())
-    dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy())
+    dataset_config_0 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_0.copy())
+    dataset_config_1 = IndexedDatasetConfig[GPTSample].from_dict(DATASET_DICT_1.copy())
     config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs(
         [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".")
     )
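As in the preference-span test earlier in this file's diff, optional fields are now requested through `GPTSamplingParameters` rather than ad-hoc keyword flags. A sketch of the pattern (only the flag matters here; the other values are placeholders and `dataset` is assumed to be an IndexedDataset built elsewhere):

    from fast_llm.data.dataset.gpt.config import GPTSamplingParameters

    parameters = GPTSamplingParameters(
        num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True
    )
    document = dataset.get_document(0, parameters=parameters)
    chosen_span, rejected_span = document.chosen_span, document.rejected_span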
diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py
index 6a2be3dcc..d7b3021fe 100644
--- a/tests/data/test_sampling.py
+++ b/tests/data/test_sampling.py
@@ -1,11 +1,10 @@
-import typing
-
 import numpy as np
 import pytest
+import torch
 
-from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType
-from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample
+from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters, ShufflingType
+from fast_llm.data.dataset.indexed import IndexedDataset
+from fast_llm.data.sample.gpt import GPTSample
 from fast_llm.utils import Assert
 from tests.data.common import (
     get_dataset_config,
@@ -62,24 +61,23 @@ def test_gpt_sampled_data():
 
 
-class SimpleGPTIndexedDataset(GPTIndexedDataset):
+class SimpleGPTIndexedDataset[SampleType: GPTSample](IndexedDataset[SampleType]):
     # TODO: worth adding to the main codebase?
     def __init__(self, samples):
         self._samples = samples
 
-    def get(self, index: int, offset=0, length=None, use_loss_masking_spans: bool = False) -> typing.Any:
-        if length is None:
-            length = len(self._samples[index])
-        assert not use_loss_masking_spans
-        return GPTSample(
-            token_ids=np.array(self._samples[index][offset : offset + length], dtype=np.int64), loss_masking_spans=None
-        )
+    def get_document(
+        self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None
+    ) -> SampleType:
+        if end is None:
+            end = len(self._samples[index])
+        return GPTSample(token_ids=torch.tensor(self._samples[index][begin:end], dtype=torch.int64))
 
     def __len__(self) -> int:
         return len(self._samples)
 
-    def get_document_sizes(self) -> np.ndarray:
-        return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64)
+    def get_document_sizes(self) -> torch.Tensor:
+        return torch.tensor([self.get_document_size(index) for index in range(len(self))], dtype=torch.int64)
 
     def get_document_size(self, index: int) -> int:
         return len(self._samples[index])
diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py
index 1fc8df1eb..e83387a24 100644
--- a/tests/data/test_slice.py
+++ b/tests/data/test_slice.py
@@ -1,4 +1,5 @@
-from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig
+from fast_llm.data.dataset.config import DatasetSliceConfig
+from fast_llm.data.sample.gpt import GPTSample
 from tests.data.common import (
     compare_indexed_dataset,
     get_dataset_config,
@@ -34,7 +35,7 @@ def test_gpt_slice():
     # samples[9:18]
     dataset = get_dataset_config(
         {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003},
-        GPTDatasetSliceConfig,
+        DatasetSliceConfig[GPTSample],
     ).build()
     compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()})
     sampled = dataset.sample(get_sampling_data(8, sequence_length=5))
diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py
index 6aa541b8c..f057c037f 100644
--- a/tests/models/test_match_megatron.py
+++ b/tests/models/test_match_megatron.py
@@ -3,12 +3,15 @@
 
 import numpy as np
 import pytest
+import torch
 
 from fast_llm.config import Field, FieldHint, config_class
 from fast_llm.data.dataset.abstract import SampledDataset
-from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData
+from fast_llm.data.dataset.config import SampledDatasetConfig
+from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample, logger
+from fast_llm.data.dataset.sampled import logger
+from fast_llm.data.sample.gpt import GPTSample
 from fast_llm.utils import Assert
 from tests.utils.compare_tensor_logs import CompareConfig
 from tests.utils.dataset import get_model_test_dataset
@@ -79,7 +82,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co
     compare_results_for_all_models(distributed_testing_config)
 
 
-@config_class(dynamic_type={GPTSampledDatasetConfig: "megatron"})
+@config_class(dynamic_type={SampledDatasetConfig: "megatron"})
 class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
     path: str = Field(
@@ -142,14 +145,14 @@ def __getitem__(self, idx: int) -> typing.Any:
         doc_f, offset_f = self._sample_idx[shuffled_idx]
         doc_l, offset_l = self._sample_idx[shuffled_idx + 1]
         sample_list = [
-            self._indexed_dataset.get(
+            self._indexed_dataset.get_document(
                 self._doc_idx[doc].item(),
-                offset=(doc == doc_f) * offset_f,
-                length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None,
+                begin=(doc == doc_f) * offset_f,
+                end=offset_l + 1 if doc == doc_l else None,
             )
             for doc in range(doc_f, doc_l + 1)
         ]
-        token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64)
+        token_ids = torch.cat([sample.token_ids for sample in sample_list])
         Assert.eq(len(token_ids), self._sequence_length + 1)
 
         return GPTSample(token_ids=token_ids)
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py
index 680faa931..b43923f4d 100644
--- a/tests/utils/dataset.py
+++ b/tests/utils/dataset.py
@@ -2,10 +2,11 @@
 import random
 
 import numpy as np
+import torch
 import yaml
 
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample
+from fast_llm.data.sample.gpt import GPTSample
 from tests.utils.global_variables import (
     DATASET_PREFIX,
     MODEL_DATASET_PREFIX,
@@ -46,14 +47,15 @@ def get_test_dataset(
         tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
 
         samples = [
-            GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts
+            GPTSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size))
+            for document in texts
         ]
         if max_spans > 0:
             lengths = np.array([max(len(sample.token_ids), 1) for sample in samples])
             spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans]))
             for sample, span in zip(samples, spans):
                 span = np.unique(span)
-                sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2)
+                sample.loss_masking_spans = torch.from_numpy(span[: len(span) // 2 * 2].reshape(-1, 2))
         GPTMemmapDataset.write_dataset(prefix, samples)
         yaml.safe_dump(