@@ -1,10 +1,8 @@
 import dataclasses
 import enum
-import json
 import pathlib
 import time
 import typing
-import warnings
 
 import yaml
 
@@ -22,8 +20,7 @@
     SamplingData,
     SamplingParameters,
 )
-from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
+from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -41,7 +38,6 @@ class ShufflingType(str, enum.Enum):
     skip_first_epoch = "skip_first_epoch"
     # Disable shuffling entirely.
     disabled = "disabled"
-    legacy = "legacy"
 
 
 @config_class()
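
Since the `legacy` shuffling mode is gone, configs that relied on it now need one of the remaining values. A minimal sketch of a sampling section choosing a mode explicitly, assuming the `seed`/`shuffle` layout used by the removed legacy wrapper further down (the seed is illustrative):

    # Hypothetical fragment; "disabled" would preserve document order instead.
    sampling = {
        "seed": 1234,
        "shuffle": "skip_first_epoch",
    }
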
@@ -222,53 +218,14 @@ def _convert_paths(self, config):
         return config
 
 
-# Add user-friendly names for the configs.
-@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"})
-class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
-    # TODO v0.3: Remove.
-    _abstract: typing.ClassVar[bool] = False
-    path: pathlib.Path = Field(
-        default=None,
-        desc="The path to a dataset directory.",
-        hint=FieldHint.core,
-    )
-
-    def _validate(self) -> None:
-        warnings.warn("`concatenated_memmap` dataset is deprecated. Use `file` instead.", DeprecationWarning)
-        super()._validate()
-
-    def build(self) -> "GPTConcatenatedDataset":
-
-        assert self.path.is_dir()
-        index_path = self.path / "index.txt"
-
-        if index_path.is_file():
-            prefixes = [self.path / line.strip() for line in index_path.open("r").readlines()]
-        else:
-            warnings.warn(
-                f"The dataset path {self.path} points to a directory."
-                " The dataset will be indexed automatically, which may be unsafe."
-                " We recommend using an index file instead."
-            )
-            prefixes = [
-                path.with_suffix("")
-                for path in self.path.iterdir()
-                if path.suffix == ".idx" and path.is_file() and path.with_suffix(".bin").is_file()
-            ]
-        dataset_config = GPTConcatenatedDatasetConfig.from_dict(
-            {"datasets": [{"type": "memmap", "path": prefix} for prefix in prefixes]}
-        )
-        return dataset_config.build()
-
-
 @config_class()
 class FimConfig(Config):
     """
     Configuration for FIM.
     """
 
     rate: float = Field(
-        # TODO: Use meaningful default now that fim is a wrapper? (bad for legacy config)
+        # TODO: Use meaningful default now that fim is a wrapper?
         default=0.0,
         desc="FIM rate for each sample.",
         hint=FieldHint.core,
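
The removed `concatenated_memmap` wrapper only expanded a directory of `.bin`/`.idx` pairs into an explicit list of memmap datasets and delegated to `GPTConcatenatedDatasetConfig`. A minimal sketch of the equivalent explicit config after this change, mirroring the dict the removed `build()` assembled (paths are illustrative):

    # Each prefix must resolve to a matching .bin/.idx pair, as the removed build() asserted.
    dataset = GPTConcatenatedDatasetConfig.from_dict(
        {
            "datasets": [
                {"type": "memmap", "path": "/data/shard_0"},
                {"type": "memmap", "path": "/data/shard_1"},
            ]
        }
    ).build()
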
@@ -352,131 +309,6 @@ def build_and_sample(
         return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling)
 
 
-class LegacyDatasetSource(str, enum.Enum):
-    """
-    An enum for the different ways to load datasets.
-    """
-
-    list = "list"
-    file = "file"
-    random = "random"
-
-
-def _validate_split(value: list[int]) -> list[int]:
-    Assert.leq(len(value), 3)
-    return value + [0] * (3 - len(value))
-
-
-def _validate_path(value: str | list[str]) -> list[str]:
-    return [value] if isinstance(value, str) else value
-
-
-@config_class()
-class GPTLegacyConfig(Config):
-    split: list[float] = Field(
-        default_factory=lambda: [969, 30, 1],
-        desc="Split ratio for train, valid and test datasets.",
-        hint=FieldHint.deprecated,
-        valid=_validate_split,
-    )
-    format: LegacyDatasetSource = Field(
-        default=LegacyDatasetSource.list,
-        desc="Format for the dataset definition.",
-        hint=FieldHint.deprecated,
-    )
-    path: list[str] = Field(
-        default_factory=list,
-        desc="Path or list of paths and weights.",
-        hint=FieldHint.deprecated,
-        valid=_validate_path,
-    )
-    fim: FimConfig = Field(
-        desc="Configuration for Fill In the Middle (FIM).",
-        hint=FieldHint.feature,
-    )
-
-
-@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"})
-class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig):
-    _abstract: typing.ClassVar[bool] = False
-
-    def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
-
-        if self.format == LegacyDatasetSource.random:
-            Assert.eq(len(self.path), 0)
-            dataset_config = GPTRandomDatasetConfig()
-        else:
-            if self.format == LegacyDatasetSource.file:
-                Assert.eq(len(self.path), 1)
-                data_path = pathlib.Path(self.path[0])
-                dataset_defs = json.load(data_path.open("r"))
-                data_base_path = data_path.parent
-                dataset_prefixes = [
-                    (data_base_path / dataset_def["prefix"]).resolve() for dataset_def in dataset_defs["datasets"]
-                ]
-                dataset_weights = normalize_probabilities(
-                    [dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]
-                )
-            elif self.format == LegacyDatasetSource.list:
-                Assert.geq(len(self.path), 1)
-                if len(self.path) == 1:
-                    dataset_prefixes, dataset_weights = [self.path[0].strip()], [1.0]
-                else:
-                    Assert.custom(lambda x: x % 2 == 0, len(self.path))
-                    dataset_prefixes = [pathlib.Path(x.strip()).resolve() for x in self.path[1::2]]
-                    assert len(dataset_prefixes) == len(set(dataset_prefixes))
-                    dataset_weights = normalize_probabilities([float(x) for x in self.path[::2]])
-            else:
-                raise NotImplementedError(self.format)
-
-            phase_splits = padded_cumsum(normalize_probabilities(self.split))
-
-            phase_index = {
-                PhaseType.training.value.lower(): 0,
-                PhaseType.validation.value.lower(): 1,
-                PhaseType.test.value.lower(): 2,
-            }[sampling.dataset_name]
-
-            dataset_configs = [
-                {
-                    "type": "slice",
-                    # TODO: this duplicates memmap datasets for each phase.
-                    "dataset": {"type": "memmap", "path": prefix},
-                    "begin": float(phase_splits[phase_index]),
-                    "end": float(phase_splits[phase_index + 1]),
-                }
-                for prefix in dataset_prefixes
-            ]
-            dataset_config = (
-                {
-                    "type": "blended",
-                    "name": "blended",
-                    "datasets": dataset_configs,
-                    "weights": dataset_weights,
-                    "legacy": True,
-                }
-                if len(dataset_configs) > 1
-                else dataset_configs[0]
-            )
-        if self.fim.rate > 0:
-            dataset_config = {
-                "type": "fim",
-                "dataset": dataset_config,
-                **self.fim.to_dict(),
-            }
-        # Legacy sampling config
-        dataset_config = {
-            "type": "sampled",
-            "dataset": dataset_config,
-            "sampling": {
-                "seed": sampling.distributed.config.seed,
-                "shuffle": "legacy",
-            },
-        }
-
-        return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling)
-
-
 @config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"})
 class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
     """
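
The removed `legacy` type rewrote itself into the modern dynamic types before building, so old configs migrate mechanically. A sketch of the expansion for a two-prefix "list" source on the training phase, assuming the default 969/30/1 split (paths, weights and the FIM rate are illustrative; the outer "sampled" wrapper is dropped because it existed only to force the removed legacy shuffling):

    # normalize_probabilities([969, 30, 1]) -> [0.969, 0.03, 0.001], and the padded
    # cumulative sum [0.0, 0.969, 0.999, 1.0] gives training the slice [0.0, 0.969).
    dataset = GPTSampledDatasetConfig.from_dict(
        {
            "type": "fim",
            "rate": 0.5,  # FimConfig fields are spread at this level, as in the removed code
            "dataset": {
                "type": "blended",
                "datasets": [
                    {
                        "type": "slice",
                        "dataset": {"type": "memmap", "path": "/data/a"},
                        "begin": 0.0,
                        "end": 0.969,
                    },
                    {
                        "type": "slice",
                        "dataset": {"type": "memmap", "path": "/data/b"},
                        "begin": 0.0,
                        "end": 0.969,
                    },
                ],
                "weights": [0.7, 0.3],
            },
        }
    ).build_and_sample(sampling)  # sampling: a GPTSamplingData, as in the removed code
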