21 changes: 17 additions & 4 deletions fast_llm/config.py
@@ -492,6 +492,10 @@ def _validate_element(cls, value, type_, name: str):
value = cls._validate_dict(value, type_, name)
elif origin is type:
value = cls._validate_type(value, type_, name)
elif issubclass(origin, Config):
# TODO: Validate arguments for config generics.
cls._validate_element_type(value, type_.__origin__, strict=False)
value.validate(_is_validating=True)
else:
raise FieldTypeError(f"Unsupported __origin__ `{origin}`")
elif not isinstance(type_, type):
@@ -806,17 +810,24 @@ def _from_dict_nested(cls, value, type_, strict: bool):
value = cls._from_dict_array(value, type_, strict)
elif issubclass(origin, dict):
value = cls._from_dict_dict(value, type_, strict)
elif issubclass(origin, Config):
value = cls._from_dict_config(value, type_, strict)
elif origin is type:
pass
else:
raise FieldTypeError(f"Unsupported __origin__ `{origin}`")
elif not isinstance(type_, type):
raise FieldTypeError(f"Not a type: {type_}.")
elif issubclass(type_, Config):
if value is MISSING:
value = {}
if isinstance(value, dict):
value = type_._from_dict(value, strict)
value = cls._from_dict_config(value, type_, strict)
return value

@classmethod
def _from_dict_config(cls, value, type_, strict: bool):
if value is MISSING:
value = {}
if isinstance(value, dict):
value = type_._from_dict(value, strict)
return value

@classmethod
@@ -938,6 +949,7 @@ def __init_subclass__(cls):
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
Assert.eq(cls.__name__, cls.__qualname__)
super().__init_subclass__()
for base_class in cls.__mro__:
if issubclass(base_class, Config) and base_class is not cls:
assert cls.__class_validated__, (
@@ -1006,6 +1018,7 @@ def __init__(self, config: ConfigType, *args, **kwargs):
def __init_subclass__(cls):
# Automatically set `config_class` based on the bound type.
# Make sure `ConfigType` is bound and respects class hierarchy.
super().__init_subclass__()
try:
config_class = None
for base in types.get_original_bases(cls):
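Review note: the new `issubclass(origin, Config)` branches and the `_from_dict_config` helper let fields be annotated with parametrized config types. A minimal, self-contained sketch of the idea, using stand-in classes rather than the real Fast-LLM ones (the actual `_from_dict` signature and validation details may differ):

```python
import typing


class Sample:  # stand-in for fast_llm.data.sample.abstract.Sample
    pass


class Config:  # stand-in for fast_llm.config.Config
    @classmethod
    def _from_dict(cls, value: dict, strict: bool = True) -> "Config":
        config = cls()
        config.__dict__.update(value)
        return config


class SampledDatasetConfig[SampleType: Sample](Config):  # a parametrized config, as in this PR
    pass


type_ = SampledDatasetConfig[Sample]  # what a field annotation now looks like
origin = typing.get_origin(type_)  # resolves to the plain config class
assert issubclass(origin, Config)  # the new dispatch condition
config = origin._from_dict({"path": "dataset.bin"})  # roughly what `_from_dict_config` does with a dict value
print(type(config).__name__)  # SampledDatasetConfig
```

As the TODO on the validation side notes, the type arguments themselves are not checked yet; only the origin class is used.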
11 changes: 10 additions & 1 deletion fast_llm/data/config.py
@@ -1,9 +1,13 @@
import enum
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.tokenizer import Tokenizer


class MultiprocessingContext(str, enum.Enum):
# Fast but risk of segfaults due to interactions with triton
@@ -29,7 +33,7 @@ class TokenizerConfig(Config):
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
path: pathlib.Path | None = Field(
path: pathlib.Path = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
@@ -39,3 +43,8 @@ class TokenizerConfig(Config):
desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.",
hint=FieldHint.core,
)

def get_tokenizer(self) -> "Tokenizer":
from fast_llm.data.tokenizer import Tokenizer

return Tokenizer(self)
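Review note: a usage sketch for the new helper, assuming the config can be constructed by keyword like a regular dataclass and that a tokenizer file actually exists at the given path (the path below is only an illustration):

```python
import pathlib

from fast_llm.data.config import TokenizerConfig

config = TokenizerConfig(path=pathlib.Path("/path/to/tokenizer.json"))
tokenizer = config.get_tokenizer()  # imports fast_llm.data.tokenizer.Tokenizer lazily
```

Keeping the import inside the method avoids pulling the tokenizer dependency in when only the config is needed, which matches the `TYPE_CHECKING` guard added above.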
3 changes: 2 additions & 1 deletion fast_llm/data/data/abstract.py
@@ -5,6 +5,7 @@
from fast_llm.config import Configurable
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.sample.abstract import Batch
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.schedule.config import BatchConfig

@@ -47,5 +48,5 @@ def get_iterator(
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[typing.Any]:
) -> typing.Iterator[Batch]:
pass
12 changes: 5 additions & 7 deletions fast_llm/data/data/gpt/config.py
@@ -1,9 +1,11 @@
import logging

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.data.config import MultiprocessingContext
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig
from fast_llm.data.dataset.config import SampledDatasetConfig
from fast_llm.data.dataset.gpt.config import GPTSamplingConfig
from fast_llm.data.sample.gpt import GPTSample
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)
@@ -19,12 +21,8 @@ class GPTDataConfig(DataConfig):

_abstract = False

tokenizer: TokenizerConfig = Field(
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
# TODO: Review field. Move closer to phase definition in training config?
datasets: dict[str, GPTSampledDatasetConfig] = Field(
datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field(
default_factory=dict,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
44 changes: 9 additions & 35 deletions fast_llm/data/data/gpt/data.py
@@ -1,11 +1,9 @@
import dataclasses
import logging
import pathlib
import typing
import warnings
from functools import partial

import numpy as np
import torch
import torch.utils.data

@@ -14,43 +12,32 @@
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.data.sample.gpt import GPTBatch, GPTSample
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.models.gpt.config import GPTBatchConfig
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None = None
sequence_lengths: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
stacked_spans = [sample.loss_masking_spans for sample in batch]
if sampling_parameters.use_preference_loss_spans:
stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch]
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
stacked_chosen_spans = [sample.chosen_span for sample in batch]
stacked_rejected_spans = [sample.rejected_span for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
sequence_lengths = [sample.sequence_lengths for sample in batch]
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
token_ids=torch.stack([sample.token_ids for sample in batch]),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
Expand All @@ -67,7 +54,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):

_datasets: dict[str, SampledDataset]
_sampling_parameters: dict[str, GPTSamplingParameters]
_tokenizer: Tokenizer | None
_is_setup: bool = False

def __init__(
@@ -108,49 +94,37 @@ def setup(
)

log_main_rank(f"Preparing dataset. This may take several minutes.")
self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)

if self._cache_directory is None:
# TODO: Avoid this
warnings.warn(f"Using the dataset directory for the index cache.")

self._datasets = {}
for dataset_name, sampling_parameters in self._sampling_parameters.items():
if self._tokenizer is not None:
# NOTE: Some models like Qwen2-1.5B-Instruct
# have vocab_size bigger in model config than in tokenizer
# TODO: Still, is it too constraining?
Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size)
if sampling_parameters.num_samples > 0:
sampling = GPTSamplingData(
config=self._config.sampling,
parameters=sampling_parameters,
cache_directory=self._cache_directory,
distributed=distributed,
dataset_name=dataset_name,
tokenizer=self._tokenizer,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True

@property
def tokenizer(self) -> Tokenizer:
assert self._is_setup
return self._tokenizer

def get_iterator(
self,
batch_config: BatchConfig,
batch_config: GPTBatchConfig,
dataset_name: str,
*,
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[typing.Any]:
) -> typing.Iterator[GPTBatch]:
assert self._is_setup

# Some dataset names may come from phases and are capitalized,
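Review note: the collate function now receives samples that already hold torch tensors, so it stacks them directly instead of going through `np.stack` and `torch.from_numpy`. A toy illustration with a stand-in sample class (the real one is `fast_llm.data.sample.gpt.GPTSample`; only the field names used by `gpt_data_collate_fn` are mirrored here):

```python
import dataclasses

import torch


@dataclasses.dataclass
class ToySample:  # hypothetical stand-in for GPTSample
    token_ids: torch.Tensor  # already a torch tensor, no numpy round trip
    loss_masking_spans: torch.Tensor | None = None
    sequence_lengths: torch.Tensor | None = None


batch = [ToySample(token_ids=torch.arange(8)), ToySample(token_ids=torch.arange(8, 16))]
token_ids = torch.stack([sample.token_ids for sample in batch])  # the new stacking path
print(token_ids.shape)  # torch.Size([2, 8])
```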
20 changes: 15 additions & 5 deletions fast_llm/data/dataset/abstract.py
@@ -1,11 +1,13 @@
import abc
import typing

from fast_llm.data.sample.abstract import Sample

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingData


class Dataset(abc.ABC):
class Dataset[SampleType: Sample](abc.ABC):
"""
A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
"""
@@ -17,24 +19,32 @@ def name(self) -> str:
A name for the dataset to facilitate identification and debugging.
"""

def __getstate__(self):
state = super().__getstate__()
# Pickling sometimes fails with bound `SampleType`.
# This is not needed at runtime, so we just drop it.
if "__orig_class__" in state:
del state["__orig_class__"]
return state


class SampledDataset(Dataset):
class SampledDataset[SampleType: Sample](Dataset[SampleType]):
"""
A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
(See the `Sampler` class below.)
"""

@abc.abstractmethod
def __getitem__(self, index: int) -> typing.Any:
def __getitem__(self, index: int) -> SampleType:
pass

@abc.abstractmethod
def __len__(self) -> int:
pass


class SamplableDataset(Dataset):
class SamplableDataset[SampleType: Sample](Dataset[SampleType]):

@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset:
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass
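Review note: a toy reproduction of what the new `Dataset.__getstate__` does, with stand-in classes rather than the real ones. Instantiating a parametrized generic attaches the alias to the instance as `__orig_class__`; dropping it from the pickled state keeps it out of the serialized dataset (the comment in the PR notes it is not needed at runtime):

```python
import pickle


class Sample:  # stand-in for fast_llm.data.sample.abstract.Sample
    pass


class Dataset[SampleType: Sample]:  # stand-in mirroring the new generic base class
    def __getstate__(self):
        state = super().__getstate__()
        if state and "__orig_class__" in state:  # drop the generic alias before pickling
            del state["__orig_class__"]
        return state


dataset = Dataset[Sample]()
print(hasattr(dataset, "__orig_class__"))  # True: the alias is attached on instantiation
restored = pickle.loads(pickle.dumps(dataset))
print(hasattr(restored, "__orig_class__"))  # False: the alias was not serialized
```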