Skip to content

Commit 5d87095

Browse files
Fix prompt-completion labeling with add_generation_prompt and warning (#4201)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 8265800 commit 5d87095

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

trl/trainer/sft_trainer.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -937,6 +937,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
937937
prompt_ids = processing_class.apply_chat_template(
938938
example["prompt"],
939939
tokenize=True,
940+
add_generation_prompt=True,
940941
tools=example.get("tools"),
941942
**example.get("chat_template_kwargs", {}),
942943
)
@@ -974,7 +975,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
974975
"token handling. Verify that the tokenizer is processing text consistently."
975976
)
976977

977-
# Create a completion mask
978+
# Create completion mask
978979
completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
979980
output["input_ids"] = prompt_completion_ids
980981
output["completion_mask"] = completion_mask
@@ -994,17 +995,17 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
994995
# Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
995996
# even for single examples, while for LLMs it returns lists of ints.
996997
processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
997-
if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
998-
raise RuntimeError(
999-
"You're using `assistant_only_loss=True`, but at least one example has no "
1000-
"assistant tokens. This usually means the tokenizer's chat template doesn't "
1001-
"generate assistant masks — it may be missing the `{% generation %}` keyword. Please "
1002-
"check the template and ensure it's correctly configured to support assistant "
1003-
"masking."
1004-
)
1005998
output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed}
1006999
else:
10071000
output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]}
1001+
1002+
if "assistant_masks" in output and 1 not in output["assistant_masks"]:
1003+
raise RuntimeError(
1004+
"You're using `assistant_only_loss=True`, but at least one example has no assistant "
1005+
"tokens. This usually means the tokenizer's chat template doesn't generate assistant "
1006+
"masks — it may be missing the `{% generation %}` keyword. Please check the template and "
1007+
"ensure it's correctly configured to support assistant masking."
1008+
)
10081009
return output
10091010

10101011
dataset = dataset.map(

0 commit comments

Comments (0)