2222from swift .tuners import Swift
2323from swift .utils import (add_version_to_work_dir , get_dist_setting , get_logger , get_pai_tensorboard_dir , is_dist ,
2424 is_local_master , is_mp , is_pai_training_job , use_torchacc )
25- from .dataset import DATASET_MAPPING , _dataset_name_exists , get_dataset , register_dataset_info_file , sample_dataset
25+ from .dataset import (DATASET_MAPPING , _dataset_name_exists , get_dataset , parse_dataset_name ,
26+ register_dataset_info_file , sample_dataset )
2627from .model import (MODEL_MAPPING , dtype_mapping , get_additional_saved_files , get_default_lora_target_modules ,
2728 get_default_template_type )
2829from .template import TEMPLATE_MAPPING
@@ -271,9 +272,18 @@ def handle_custom_dataset_info(self):
271272 def _handle_dataset_sample (self ):
272273 # compatibility. (Deprecated)
273274 # Avoid post-processing
274- if len (self .dataset ) == 1 and '#' not in self .dataset [0 ] and self .train_dataset_sample >= 0 :
275- self .dataset [0 ] = f'{ self .dataset [0 ]} #{ self .train_dataset_sample } '
276- self .train_dataset_sample = - 1
275+ if len (self .dataset ) != 1 or self .train_dataset_sample == - 1 :
276+ return
277+ _dataset = self .dataset [0 ]
278+ train_sample = parse_dataset_name (_dataset )[3 ]
279+ if train_sample is None :
280+ train_sample = self .train_dataset_sample
281+ elif self .train_dataset_sample < train_sample :
282+ train_sample = self .train_dataset_sample
283+ _dataset = _dataset [:_dataset .find ('#' )]
284+ _dataset = f'{ _dataset } #{ train_sample } '
285+ self .dataset [0 ] = _dataset
286+ self .train_dataset_sample = - 1
277287
278288 def _register_self_cognition (self : Union ['SftArguments' , 'InferArguments' ]) -> None :
279289
@@ -688,11 +698,9 @@ def _prepare_modules_to_save(self, modules_to_save) -> List[str]:
688698
689699 def __post_init__ (self ) -> None :
690700 self .handle_compatibility ()
691- self ._register_self_cognition ()
692701 if len (self .val_dataset ) > 0 :
693702 self .dataset_test_ratio = 0.0
694703 logger .info ('Using val_dataset, ignoring dataset_test_ratio' )
695- self ._handle_dataset_sample ()
696704 if is_pai_training_job ():
697705 self ._handle_pai_compat ()
698706 ds_config_folder = os .path .abspath (os .path .join (__file__ , '..' , '..' , 'ds_config' ))
@@ -707,6 +715,8 @@ def __post_init__(self) -> None:
707715 break
708716
709717 self .handle_path ()
718+ self ._handle_dataset_sample ()
719+ self ._register_self_cognition ()
710720 self .handle_custom_register ()
711721 self .handle_custom_dataset_info ()
712722 self .set_model_type ()
@@ -1059,7 +1069,6 @@ def __post_init__(self) -> None:
10591069 logger .warning (f'The checkpoint dir { self .ckpt_dir } passed in is invalid, please make sure'
10601070 'the dir contains a `configuration.json` file.' )
10611071 self .handle_compatibility ()
1062- self ._register_self_cognition ()
10631072 if len (self .val_dataset ) > 0 :
10641073 self .dataset_test_ratio = 0.0
10651074 logger .info ('Using val_dataset, ignoring dataset_test_ratio' )
@@ -1073,6 +1082,7 @@ def __post_init__(self) -> None:
10731082 else :
10741083 assert self .load_dataset_config is False , 'You need to first set `--load_args_from_ckpt_dir true`.'
10751084 self ._handle_dataset_sample ()
1085+ self ._register_self_cognition ()
10761086 self .handle_custom_register ()
10771087 self .handle_custom_dataset_info ()
10781088 self .set_model_type ()
0 commit comments