@@ -213,15 +213,18 @@ def handle_compatibility(self: Union['SftArguments', 'InferArguments']) -> None:
213213 v = _mapping [k ]
214214 setattr (self , _name , v )
215215 break
216- if isinstance (self .dataset , str ):
217- self .dataset = [self .dataset ]
218- if len (self .dataset ) == 1 and ',' in self .dataset [0 ]:
219- self .dataset = self .dataset [0 ].split (',' )
220- for i , dataset in enumerate (self .dataset ):
221- if dataset in dataset_name_mapping :
222- self .dataset [i ] = dataset_name_mapping [dataset ]
223- for d in self .dataset :
224- assert ',' not in d , f'dataset: { d } , please use `/`'
216+ for key in ['dataset' , 'val_dataset' ]:
217+ _dataset = getattr (self , key )
218+ if isinstance (_dataset , str ):
219+ _dataset = [_dataset ]
220+ if len (_dataset ) == 1 and ',' in _dataset [0 ]:
221+ _dataset = _dataset [0 ].split (',' )
222+ for i , d in enumerate (_dataset ):
223+ if d in dataset_name_mapping :
224+ _dataset [i ] = dataset_name_mapping [d ]
225+ for d in _dataset :
226+ assert ',' not in d , f'dataset: { d } , please use `/`'
227+ setattr (self , key , _dataset )
225228 if self .truncation_strategy == 'ignore' :
226229 self .truncation_strategy = 'delete'
227230 if self .safe_serialization is not None :
@@ -1072,12 +1075,12 @@ def __post_init__(self) -> None:
10721075 self .torch_dtype , _ , _ = self .select_dtype ()
10731076 self .prepare_template ()
10741077 if self .eval_human is None :
1075- if not len (self .dataset ) > 0 :
1078+ if len (self .dataset ) == 0 and len ( self . val_dataset ) == 0 :
10761079 self .eval_human = True
10771080 else :
10781081 self .eval_human = False
10791082 logger .info (f'Setting self.eval_human: { self .eval_human } ' )
1080- elif self .eval_human is False and not len (self .dataset ) > 0 :
1083+ elif self .eval_human is False and len (self .dataset ) == 0 and len ( self . val_dataset ) == 0 :
10811084 raise ValueError ('Please provide the dataset or set `--load_dataset_config true`.' )
10821085
10831086 # compatibility
0 commit comments