
Commit ca3b000

Version 0.3.0 (#364)
1 parent 8864b23 commit ca3b000


51 files changed: +218, -977 lines

examples/mistral.yaml

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ model:
   multi_stage:
     zero_stage: 2
   distributed:
-    training_dtype: bf16
+    compute_dtype: bf16
     seed: 984059
 run:
   experiment_dir: mistral_example
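The only functional change in this example is the rename of the `training_dtype` field to `compute_dtype` under `model.distributed`. For an existing pre-0.3.0 config, a one-off rewrite along these lines should be enough; this is a hedged sketch, not part of Fast-LLM, the file path is a placeholder, and PyYAML is assumed to be available:

import pathlib

import yaml

# Rename the old key in an existing config file; assumes the same layout as
# examples/mistral.yaml (distributed settings nested under `model`).
path = pathlib.Path("examples/mistral.yaml")
config = yaml.safe_load(path.read_text())
distributed = config.get("model", {}).get("distributed", {})
if "training_dtype" in distributed:
    distributed["compute_dtype"] = distributed.pop("training_dtype")
path.write_text(yaml.safe_dump(config, sort_keys=False))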

fast_llm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.2.0"
+__version__ = "0.3.0"

fast_llm/config.py

Lines changed: 12 additions & 38 deletions
@@ -759,58 +759,32 @@ def from_dict(
         return cls._from_dict(default, strict)

     @classmethod
-    def from_flat_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-    ) -> typing.Self:
-        # TODO v0.3: Remove flat format
-        return cls._from_dict(default, strict, True)
-
-    @classmethod
-    def _from_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-        flat: bool = False,
-    ) -> typing.Self:
-        # TODO v0.3: Remove flat format
+    def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
         out_arg_dict = {"_from_dict_check": True}
-
-        # TODO v0.3: Remove backward compatibility fix
-        if "__class__" in default:
-            del default["__class__"]
-
         try:
             actual_cls = cls.get_subclass(default.get("type"))
         except KeyError:
             # Try to postpone error to validation.
             actual_cls = cls

         if actual_cls is not None and actual_cls is not cls:
-            return actual_cls._from_dict(default, strict=strict, flat=flat)
+            return actual_cls._from_dict(default, strict=strict)

         # Do not validate yet in case the root class sets cross-dependencies in validation.
         with NoAutoValidate():
             for name, field in cls.fields():
                 if not field.init or field._field_type != dataclasses._FIELD:  # noqa
                     continue
-                if flat:
-                    if isinstance(field.type, type) and issubclass(field.type, Config):
-                        out_arg_dict[name] = field.type._from_dict(default, False, True)
-                    elif name in default:
-                        out_arg_dict[name] = default.pop(name)
-                else:
-                    # Check for nested configs to instantiate.
-                    try:
-                        value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict)
-                        if value is not MISSING:
-                            out_arg_dict[name] = value
-                    except FieldTypeError as e:
-                        raise FieldTypeError(
-                            f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: "
-                            + ", ".join(e.args)
-                        )
+                # Check for nested configs to instantiate.
+                try:
+                    value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict)
+                    if value is not MISSING:
+                        out_arg_dict[name] = value
+                except FieldTypeError as e:
+                    raise FieldTypeError(
+                        f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: "
+                        + ", ".join(e.args)
+                    )
         out = cls(**out_arg_dict)  # noqa
         if strict and default:
             out._unknown_fields = default.copy()
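The flat code path is gone: `from_flat_dict` and the `flat` argument of `_from_dict` are removed, so configs are now built only from nested dictionaries, with nested config fields instantiated recursively through `_from_dict_nested`. A minimal, self-contained sketch of that surviving behaviour, using plain dataclasses rather than Fast-LLM's `Config` machinery (all names below are illustrative):

import dataclasses

@dataclasses.dataclass
class DistributedSketch:
    compute_dtype: str = "float32"
    seed: int = 0

@dataclasses.dataclass
class ModelSketch:
    zero_stage: int = 1
    distributed: DistributedSketch = dataclasses.field(default_factory=DistributedSketch)

def from_nested_dict(cls, data: dict):
    # Mirror the kept branch of `_from_dict`: read each field from a *nested* dict
    # and recurse into fields whose type is itself a config-like class.
    kwargs = {}
    for field in dataclasses.fields(cls):
        if field.name not in data:
            continue
        value = data[field.name]
        if dataclasses.is_dataclass(field.type):
            value = from_nested_dict(field.type, value)
        kwargs[field.name] = value
    return cls(**kwargs)

model = from_nested_dict(ModelSketch, {"zero_stage": 2, "distributed": {"compute_dtype": "bf16"}})
print(model)  # ModelSketch(zero_stage=2, distributed=DistributedSketch(compute_dtype='bf16', seed=0))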

fast_llm/data/data/gpt/config.py

Lines changed: 2 additions & 38 deletions
@@ -1,23 +1,16 @@
 import logging
-import typing

 from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
 from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
 from fast_llm.data.data.config import DataConfig
-from fast_llm.data.dataset.gpt.config import (
-    GPTLegacyConfig,
-    GPTLegacyDatasetConfig,
-    GPTSampledDatasetConfig,
-    GPTSamplingConfig,
-)
-from fast_llm.engine.distributed.config import PhaseType
+from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig
 from fast_llm.utils import Assert

 logger = logging.getLogger(__name__)


 @config_class()
-class GPTDataConfig(DataConfig, GPTLegacyConfig):
+class GPTDataConfig(DataConfig):
     """
     Configuration for the dataset(s), split and sampling.
     Currently hard-coded to a GPT dataset.
@@ -48,32 +41,3 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
         desc="Multiprocessing context. Do not touch.",
         hint=FieldHint.expert,
     )
-
-    def _validate(self) -> None:
-        if not self.datasets:
-            logger.warning(
-                "Using the legacy dataset definition format." " Specify it through `data.datasets` instead."
-            )
-            self.datasets = {
-                phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False)
-                for phase in (PhaseType.training, PhaseType.validation, PhaseType.test)
-            }
-        super()._validate()
-
-    @classmethod
-    def _from_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-        flat: bool = False,
-    ) -> typing.Self:
-        # TODO v0.x: Remove backward compatibility.
-        if "datasets" in default:
-            for phase in PhaseType:
-                if phase.value in default["datasets"]:
-                    rename = phase.value.lower()
-                    logger.warning(f"Renaming dataset {phase.value} to {rename}")
-                    assert rename not in default["datasets"]
-                    default["datasets"][rename] = default["datasets"].pop(phase.value)
-
-        return super()._from_dict(default, strict, flat)
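`GPTDataConfig` no longer inherits from `GPTLegacyConfig`, and the fallback that synthesized `data.datasets` from the legacy `format`/`path`/`split` fields is removed: datasets now have to be declared explicitly, keyed by lower-case phase name. A hedged sketch of what that explicit declaration looks like as a plain dict (the `memmap` type and `path` field appear in the legacy builder removed further down; the concrete paths are placeholders):

# Explicit per-phase dataset declaration replacing the legacy format/path/split trio.
data_config = {
    "datasets": {
        "training": {"type": "memmap", "path": "/data/tokens/train"},
        "validation": {"type": "memmap", "path": "/data/tokens/valid"},
        "test": {"type": "memmap", "path": "/data/tokens/test"},
    },
}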

fast_llm/data/dataset/config.py

Lines changed: 2 additions & 17 deletions
@@ -204,11 +204,6 @@ class BlendedDatasetConfig(SampledDatasetConfig):
         desc="The blending weight of each dataset.",
         hint=FieldHint.core,
     )
-    legacy: bool = Field(
-        default=False,
-        desc="Use the legacy formulas for sub-dataset seeds and sample sizes.",
-        hint=FieldHint.deprecated,
-    )

     def _validate(self) -> None:
         self.weights = normalize_probabilities(self.weights)
@@ -231,20 +226,10 @@ def build_and_sample(
                     sampling,
                     parameters=dataclasses.replace(
                         sampling.parameters,
-                        num_samples=(
-                            math.ceil(
-                                weight
-                                * (
-                                    sampling.parameters.num_samples
-                                    + 5 * (sampling.parameters.num_samples * (1 - weight)) ** 0.5
-                                )
-                            )
-                            if self.legacy
-                            else math.ceil(weight * sampling.parameters.num_samples) + 1
-                        ),
+                        num_samples=math.ceil(weight * sampling.parameters.num_samples) + 1,
                     ),
                     # TODO: Seed may not be unique for nested blended datasets.
-                    config=sampling.config.to_copy({"seed": sampling.config.seed + i * (0 if self.legacy else 697)}),
+                    config=sampling.config.to_copy({"seed": sampling.config.seed + i * 697}),
                 ),
             )
             for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True))
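With the `legacy` flag dropped, blending always uses a single pair of formulas, both visible in the diff above: each sub-dataset is asked for `ceil(weight * num_samples) + 1` samples (one extra as a margin against rounding), and its sampling seed is offset by `i * 697` so sub-datasets are decorrelated. A small runnable illustration of that bookkeeping in plain Python, not the Fast-LLM classes:

import math

def blended_plan(num_samples: int, weights: list[float], seed: int) -> list[tuple[int, int]]:
    # Returns (samples requested, sampling seed) for each blended sub-dataset.
    total = sum(weights)  # stands in for normalize_probabilities
    return [
        (math.ceil(w / total * num_samples) + 1, seed + i * 697)
        for i, w in enumerate(weights)
    ]

print(blended_plan(1000, [0.7, 0.3], seed=984059))
# [(701, 984059), (301, 984756)]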

fast_llm/data/dataset/gpt/config.py

Lines changed: 2 additions & 170 deletions
@@ -1,10 +1,8 @@
 import dataclasses
 import enum
-import json
 import pathlib
 import time
 import typing
-import warnings

 import yaml

@@ -22,8 +20,7 @@
     SamplingData,
     SamplingParameters,
 )
-from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
+from fast_llm.utils import Assert

 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -41,7 +38,6 @@ class ShufflingType(str, enum.Enum):
     skip_first_epoch = "skip_first_epoch"
     # Disable shuffling entirely.
     disabled = "disabled"
-    legacy = "legacy"


 @config_class()
@@ -222,53 +218,14 @@ def _convert_paths(self, config):
         return config


-# Add user-friendly names for the configs.
-@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"})
-class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
-    # TODO v0.3: Remove.
-    _abstract: typing.ClassVar[bool] = False
-    path: pathlib.Path = Field(
-        default=None,
-        desc="The path to a dataset directory.",
-        hint=FieldHint.core,
-    )
-
-    def _validate(self) -> None:
-        warnings.warn("`concatenated_memmap` dataset is deprecated. Use `file` instead.", DeprecationWarning)
-        super()._validate()
-
-    def build(self) -> "GPTConcatenatedDataset":
-
-        assert self.path.is_dir()
-        index_path = self.path / "index.txt"
-
-        if index_path.is_file():
-            prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()]
-        else:
-            warnings.warn(
-                f"The dataset path {self.path} points to a directory."
-                " The dataset will be indexed automatically, which may be unsafe."
-                " We recommend using an index file instead."
-            )
-            prefixes = [
-                path.with_suffix("")
-                for path in self.path.iterdir()
-                if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file()
-            ]
-        dataset_config = GPTConcatenatedDatasetConfig.from_dict(
-            {"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]}
-        )
-        return dataset_config.build()
-
-
 @config_class()
 class FimConfig(Config):
     """
     Configuration for FIM.
     """

     rate: float = Field(
-        # TODO: Use meaningful default now that fim is a wrapper? (bad for legacy config)
+        # TODO: Use meaningful default now that fim is a wrapper?
         default=0.0,
         desc="FIM rate for each sample.",
         hint=FieldHint.core,
@@ -352,131 +309,6 @@ def build_and_sample(
         return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling)


-class LegacyDatasetSource(str, enum.Enum):
-    """
-    An enum for the different ways to load datasets.
-    """
-
-    list = "list"
-    file = "file"
-    random = "random"
-
-
-def _validate_split(value: list[int]) -> list[int]:
-    Assert.leq(len(value), 3)
-    return value + [0] * (len(value) - 3)
-
-
-def _validate_path(value: str | list[str]) -> list[str]:
-    return [value] if isinstance(value, str) else value
-
-
-@config_class()
-class GPTLegacyConfig(Config):
-    split: list[float] = Field(
-        default_factory=lambda: [969, 30, 1],
-        desc="Split ratio for train, valid and test datasets.",
-        hint=FieldHint.deprecated,
-        valid=_validate_split,
-    )
-    format: LegacyDatasetSource = Field(
-        default=LegacyDatasetSource.list,
-        desc="Format for the dataset definition.",
-        hint=FieldHint.deprecated,
-    )
-    path: list[str] = Field(
-        default_factory=list,
-        desc="Path or list of paths and weights.",
-        hint=FieldHint.deprecated,
-        valid=_validate_path,
-    )
-    fim: FimConfig = Field(
-        desc="Configuration for Fill In the Middle (FIM).",
-        hint=FieldHint.feature,
-    )
-
-
-@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"})
-class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig):
-    _abstract: typing.ClassVar[bool] = False
-
-    def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
-
-        if self.format == LegacyDatasetSource.random:
-            Assert.eq(len(self.path), 0)
-            dataset_config = GPTRandomDatasetConfig()
-        else:
-            if self.format == LegacyDatasetSource.file:
-                Assert.eq(len(self.path), 1)
-                data_path = pathlib.Path(self.path[0])
-                dataset_defs = json.load(data_path.open("r"))
-                data_base_path = data_path.parent
-                dataset_prefixes = [
-                    (data_base_path / dataset_def["prefix"]).resolve() for dataset_def in dataset_defs["datasets"]
-                ]
-                dataset_weights = normalize_probabilities(
-                    [dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]
-                )
-            elif self.format == LegacyDatasetSource.list:
-                Assert.geq(len(self.path), 1)
-                if len(self.path) == 1:
-                    dataset_prefixes, dataset_weights = [self.path[0].strip()], [1.0]
-                else:
-                    Assert.custom(lambda x: x % 2 == 0, len(self.path))
-                    dataset_prefixes = [pathlib.Path(x.strip()).resolve() for x in self.path[1::2]]
-                    assert len(dataset_prefixes) == len(set(dataset_prefixes))
-                    dataset_weights = normalize_probabilities([float(x) for x in self.path[::2]])
-            else:
-                raise NotImplementedError(self.format)
-
-            phase_splits = padded_cumsum(normalize_probabilities(self.split))
-
-            phase_index = {
-                PhaseType.training.value.lower(): 0,
-                PhaseType.validation.value.lower(): 1,
-                PhaseType.test.value.lower(): 2,
-            }[sampling.dataset_name]
-
-            dataset_configs = [
-                {
-                    "type": "slice",
-                    # TODO: this duplicates memmap datasets for each phase.
-                    "dataset": {"type": "memmap", "path": prefix},
-                    "begin": float(phase_splits[phase_index]),
-                    "end": float(phase_splits[phase_index + 1]),
-                }
-                for prefix in dataset_prefixes
-            ]
-            dataset_config = (
-                {
-                    "type": "blended",
-                    "name": "blended",
-                    "datasets": dataset_configs,
-                    "weights": dataset_weights,
-                    "legacy": True,
-                }
-                if len(dataset_configs) > 1
-                else dataset_configs[0]
-            )
-        if self.fim.rate > 0:
-            dataset_config = {
-                "type": "fim",
-                "dataset": dataset_config,
-                **self.fim.to_dict(),
-            }
-        # Legacy sampling config
-        dataset_config = {
-            "type": "sampled",
-            "dataset": dataset_config,
-            "sampling": {
-                "seed": sampling.distributed.config.seed,
-                "shuffle": "legacy",
-            },
-        }
-
-        return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling)
-
-
 @config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"})
 class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
     """
