@@ -140,14 +140,14 @@ def _get_merged_dataset(
140140 model_args : "ModelArguments" ,
141141 data_args : "DataArguments" ,
142142 training_args : "Seq2SeqTrainingArguments" ,
143- stage : Literal ["pt" , "sft" , "rm" , "ppo" , "kto" ],
143+ stage : Literal ["pt" , "sft" , "rm" , "ppo" , "kto" , "iopo" ],
144144) -> Optional [Union ["Dataset" , "IterableDataset" ]]:
145145 if dataset_names is None :
146146 return None
147147
148148 datasets = []
149149 for dataset_attr in get_dataset_list (dataset_names , data_args .dataset_dir ):
150- if (stage == "rm" and dataset_attr .ranking is False ) or (stage != "rm" and dataset_attr .ranking is True ):
150+ if (( stage == "rm" or stage == "iopo" ) and dataset_attr .ranking is False ) or (( stage != "rm" and stage != "iopo" ) and dataset_attr .ranking is True ):
151151 raise ValueError ("The dataset is not applicable in the current training stage." )
152152
153153 datasets .append (_load_single_dataset (dataset_attr , model_args , data_args , training_args ))
@@ -199,7 +199,7 @@ def get_dataset(
199199 model_args : "ModelArguments" ,
200200 data_args : "DataArguments" ,
201201 training_args : "Seq2SeqTrainingArguments" ,
202- stage : Literal ["pt" , "sft" , "rm" , "ppo" , "kto" ],
202+ stage : Literal ["pt" , "sft" , "rm" , "ppo" , "kto" , "iopo" ],
203203 tokenizer : "PreTrainedTokenizer" ,
204204 processor : Optional ["ProcessorMixin" ] = None ,
205205) -> "DatasetModule" :
0 commit comments