Skip to content

Commit 8dd65d5

Browse files
authored
Merge pull request #179 from AIRobotZhang/patch-1
Update loader.py
2 parents e0cc87a + 67e9698 commit 8dd65d5

File tree

1 file changed

+3
-3
lines changed
  • IOPO/Method-IOPO/src/llamafactory/data

1 file changed

+3
-3
lines changed

IOPO/Method-IOPO/src/llamafactory/data/loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ def _get_merged_dataset(
140140
model_args: "ModelArguments",
141141
data_args: "DataArguments",
142142
training_args: "Seq2SeqTrainingArguments",
143-
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
143+
stage: Literal["pt", "sft", "rm", "ppo", "kto", "iopo"],
144144
) -> Optional[Union["Dataset", "IterableDataset"]]:
145145
if dataset_names is None:
146146
return None
147147

148148
datasets = []
149149
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
150-
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
150+
if ((stage == "rm" or stage == "iopo") and dataset_attr.ranking is False) or ((stage != "rm" and stage != "iopo") and dataset_attr.ranking is True):
151151
raise ValueError("The dataset is not applicable in the current training stage.")
152152

153153
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
@@ -199,7 +199,7 @@ def get_dataset(
199199
model_args: "ModelArguments",
200200
data_args: "DataArguments",
201201
training_args: "Seq2SeqTrainingArguments",
202-
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
202+
stage: Literal["pt", "sft", "rm", "ppo", "kto", "iopo"],
203203
tokenizer: "PreTrainedTokenizer",
204204
processor: Optional["ProcessorMixin"] = None,
205205
) -> "DatasetModule":

0 commit comments

Comments
 (0)