diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 3d2009a398..f1e96d09e5 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1829,7 +1829,12 @@ def test_train_vlm_gemma_3n(self):
     )
     @pytest.mark.parametrize(
         "dataset_config",
-        ["conversational_language_modeling", "conversational_prompt_completion", "standard_prompt_completion"],
+        [
+            "conversational_language_modeling",
+            "conversational_prompt_completion",
+            "standard_language_modeling",  # Regression test for #5334
+            "standard_prompt_completion",
+        ],
     )
     @require_vision
     def test_train_vlm_text_only_data(self, model_id, dataset_config):
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 983f93941b..08dbe52413 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -1167,7 +1167,10 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo
         processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
         output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed}
     else:
-        output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]}
+        # Fix transformers inconsistency: for VLMs, processing_class returns lists of lists
+        # even for single examples, while for LLMs it returns lists of ints.
+        ids = processing_class(text=example[dataset_text_field])["input_ids"]
+        output = {"input_ids": ids[0] if isinstance(ids[0], list) else ids}
     if "assistant_masks" in output and 1 not in output["assistant_masks"]:
         raise RuntimeError(