1313import torch .distributed as dist
1414import transformers
1515from datasets import Dataset as HfDataset
16+ from datasets import IterableDataset as HfIterableDataset
1617from datasets import concatenate_datasets
1718from packaging import version
1819from torch import dtype as Dtype
3435from .utils import is_lmdeploy_available , is_quant_model , is_vllm_available
3536
3637logger = get_logger ()
38+ DATASET_TYPE = Union [HfDataset , HfIterableDataset ]
3739
3840
3941def is_adapter (sft_type : str ) -> bool :
@@ -374,11 +376,14 @@ def _register_self_cognition(self: Union['SftArguments', 'InferArguments']) -> N
374376 'Representing the model name and model author in Chinese and English.' )
375377 setattr (self , k , v )
376378
377- def _handle_dataset_compat (self : Union ['SftArguments' , 'InferArguments' ], train_dataset : Optional [HfDataset ],
378- val_dataset : Optional [HfDataset ]) -> Tuple [Optional [HfDataset ], Optional [HfDataset ]]:
379+ def _handle_dataset_compat (
380+ self : Union ['SftArguments' , 'InferArguments' ], train_dataset : Optional [DATASET_TYPE ],
381+ val_dataset : Optional [DATASET_TYPE ]) -> Tuple [Optional [DATASET_TYPE ], Optional [DATASET_TYPE ]]:
379382 # compatibility. (Deprecated)
383+ streaming = getattr (self , 'streaming' , False )
380384 random_state = np .random .RandomState (self .dataset_seed )
381385 val_dataset_sample = self .val_dataset_sample
386+
382387 if train_dataset is not None and self .train_dataset_sample >= 0 :
383388 train_dataset_sample = min (self .train_dataset_sample , train_dataset .shape [0 ])
384389 if train_dataset .shape [0 ] > train_dataset_sample :
@@ -388,10 +393,13 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
388393 if val_dataset_sample is None :
389394 val_dataset_sample = max (int (train_dataset_sample * self .dataset_test_ratio ), 1 )
390395 if val_dataset is not None and val_dataset_sample is not None and val_dataset_sample >= 0 :
391- if val_dataset .shape [0 ] > val_dataset_sample :
396+ if not streaming and val_dataset .shape [0 ] > val_dataset_sample :
392397 logger .info (f'val_dataset_sample: { val_dataset_sample } ' )
393398 val_idxs = random_state .permutation (val_dataset_sample )
394399 val_dataset = val_dataset .select (val_idxs )
400+ elif streaming :
401+ val_dataset = val_dataset .shuffle (
402+ seed = self .dataset_seed , buffer_size = self .streaming_buffer_size ).take (val_dataset_sample )
395403
396404 if (train_dataset is None or not hasattr (self , 'train_dataset_mix_ratio' ) or self .train_dataset_mix_ratio <= 0
397405 or len (self .train_dataset_mix_ds ) == 0 ):
@@ -401,7 +409,11 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
401409 logger .info (f'train_dataset_mix_ds: { self .train_dataset_mix_ds } ' )
402410 logger .info (f'len(train_dataset): { len (train_dataset )} , mix_dataset_sample: { mix_dataset_sample } ' )
403411 mixed_dataset = get_dataset (
404- self .train_dataset_mix_ds , 0.0 , random_state , check_dataset_strategy = self .check_dataset_strategy )[0 ]
412+ self .train_dataset_mix_ds ,
413+ 0.0 ,
414+ random_state ,
415+ check_dataset_strategy = self .check_dataset_strategy ,
416+ streaming = streaming )[0 ]
405417 if len (mixed_dataset ) < mix_dataset_sample :
406418 logger .warn (f'The length of dataset used for mixin: { self .train_dataset_mix_ds } are '
407419 'lesser than the ratio required by the `train_dataset_mix_ratio` '
@@ -590,7 +602,10 @@ class SftArguments(ArgumentsBase):
590602 max_length : int = 2048 # -1: no limit
591603 truncation_strategy : Literal ['delete' , 'truncation_left' ] = 'delete'
592604 check_dataset_strategy : Literal ['none' , 'discard' , 'error' , 'warning' ] = 'none'
593-
605+ # streaming dataset
606+ streaming : bool = False
607+ streaming_val_size : int = 0
608+ streaming_buffer_size : int = 16384
594609 # Chinese name and English name
595610 model_name : List [str ] = field (default_factory = lambda : [None , None ], metadata = {'help' : "e.g. ['小黄', 'Xiao Huang']" })
596611 model_author : List [str ] = field (
@@ -1025,7 +1040,8 @@ def __post_init__(self) -> None:
10251040 if self .gradient_accumulation_steps is None :
10261041 self .gradient_accumulation_steps = math .ceil (16 / self .batch_size / self .world_size )
10271042 template_info = TEMPLATE_MAPPING [self .template_type ]
1028- if self .lazy_tokenize is None :
1043+ self ._handle_streaming_args ()
1044+ if self .lazy_tokenize is None and not self .streaming :
10291045 self .lazy_tokenize = template_info .get ('lazy_tokenize' , False )
10301046 logger .info (f'Setting args.lazy_tokenize: { self .lazy_tokenize } ' )
10311047 if self .dataloader_num_workers is None :
@@ -1095,6 +1111,9 @@ def _init_training_args(self) -> None:
10951111 else :
10961112 kwargs ['evaluation_strategy' ] = self .evaluation_strategy
10971113
1114+ if 'accelerator_config' in parameters :
1115+ kwargs ['accelerator_config' ] = {'dispatch_batches' : False }
1116+
10981117 training_args = Seq2SeqTrainingArguments (
10991118 output_dir = self .output_dir ,
11001119 logging_dir = self .logging_dir ,
@@ -1181,6 +1200,42 @@ def _handle_pai_compat(self) -> None:
11811200 self .add_output_dir_suffix = False
11821201 logger .info (f'Setting args.add_output_dir_suffix: { self .add_output_dir_suffix } ' )
11831202
1203+ def _handle_streaming_args (self ) -> None :
1204+ if not self .streaming :
1205+ return
1206+ if self .max_steps == - 1 :
1207+ raise ValueError ('Please specify `max_steps` in streaming mode.' )
1208+
1209+ if self .packing :
1210+ self .packing = False
1211+ logger .warning ('Packing is not supported for streaming dataset, set to False' )
1212+
1213+ if self .test_oom_error :
1214+ self .test_oom_error = False
1215+ logger .warning ('test_oom_error is not supported for streaming dataset, set to False' )
1216+
1217+ if self .lazy_tokenize :
1218+ self .lazy_tokenize = False
1219+ logger .info ('lazy_tokenize set to False in streaming dataset' )
1220+
1221+ if self .train_dataset_mix_ratio > 0 :
1222+ logger .warning ('train_dataset_mix_ratio is not supported for streaming dataset, set to 0' )
1223+ self .train_dataset_mix_ratio = 0
1224+
1225+ if self .dataset_test_ratio > 0 :
1226+ logger .info ('Set dataset_test_ratio to 0 in streaming mode.'
1227+ 'You can manually set val_dataset and val_dataset_sample.'
1228+ 'or set streaming_val_size instead to split from train dataset' )
1229+ self .dataset_test_ratio = 0
1230+
1231+ if self .train_dataset_sample > 0 :
1232+ logger .warning ('train_dataset_sample is not supported for streaming dataset, set to -1' )
1233+ self .train_dataset_sample = - 1
1234+
1235+ if self .dataloader_num_workers is None or self .dataloader_num_workers > 0 :
1236+ logger .info ('Set dataloader_num_workers to 0 in streaming mode' )
1237+ self .dataloader_num_workers = 0
1238+
11841239
11851240@dataclass
11861241class InferArguments (ArgumentsBase ):
0 commit comments