Skip to content

Commit 8122d6d

Browse files
Merge branch 'main' into nemotron3-tiny-tests
2 parents 4f9b1f8 + 004ee19 commit 8122d6d

File tree

11 files changed

+104
-74
lines changed

11 files changed

+104
-74
lines changed

tests/test_dpo_trainer.py

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from transformers.utils import is_peft_available
2323

2424
from trl import DPOConfig, DPOTrainer
25-
from trl.trainer.dpo_trainer import DataCollatorForPreference
25+
from trl.trainer.dpo_trainer import DataCollatorForPreference, DataCollatorForVisionPreference
2626

2727
from .testing_utils import (
2828
TrlTestCase,
@@ -132,6 +132,38 @@ def test_with_pad_to_multiple_of(self):
132132
torch.testing.assert_close(result["input_ids"], expected_input_ids)
133133

134134

135+
class TestDataCollatorForVisionPreference(TrlTestCase):
136+
@pytest.mark.skipif(
137+
Version(transformers.__version__) < Version("5.3.0"),
138+
reason="mm_token_type_ids are returned by default since transformers-5.3.0 (see transformers#43972)",
139+
)
140+
@require_vision
141+
def test_mm_token_type_ids_shape(self):
142+
# Regression test: when the processor returns mm_token_type_ids (e.g. Qwen2.5-VL after
143+
# transformers#43972), the collator must concatenate it with zeros for the completion part
144+
# so that its shape matches input_ids. Without the fix this raises an IndexError in the model.
145+
from PIL import Image
146+
from transformers import AutoProcessor
147+
148+
processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration")
149+
collator = DataCollatorForVisionPreference(processor)
150+
image = Image.new("RGB", (16, 16))
151+
examples = [
152+
{
153+
"images": [image],
154+
"prompt": [{"role": "user", "content": "What is this?"}],
155+
"chosen": [{"role": "assistant", "content": "A red square."}],
156+
"rejected": [{"role": "assistant", "content": "A blue circle."}],
157+
}
158+
]
159+
output = collator(examples)
160+
assert "mm_token_type_ids" in output
161+
assert output["mm_token_type_ids"].shape == output["input_ids"].shape, (
162+
f"mm_token_type_ids shape {output['mm_token_type_ids'].shape} != "
163+
f"input_ids shape {output['input_ids'].shape}"
164+
)
165+
166+
135167
class TestDPOTrainer(TrlTestCase):
136168
@pytest.mark.parametrize(
137169
"model_id",

trl/experimental/bco/bco_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -439,6 +439,9 @@ def __init__(
439439
if type(args) is TrainingArguments:
440440
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
441441

442+
if train_dataset is None:
443+
raise ValueError("`train_dataset` is required")
444+
442445
if not isinstance(model, str) and model is not None and ref_model is model:
443446
raise ValueError(
444447
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "

trl/experimental/cpo/cpo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,9 @@ def __init__(
145145
peft_config: dict | None = None,
146146
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
147147
):
148+
if train_dataset is None:
149+
raise ValueError("`train_dataset` is required")
150+
148151
if args.model_init_kwargs is None:
149152
model_init_kwargs = {}
150153
elif not isinstance(model, str):

trl/experimental/kto/kto_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -331,6 +331,9 @@ def __init__(
331331
if type(args) is TrainingArguments:
332332
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
333333

334+
if train_dataset is None:
335+
raise ValueError("`train_dataset` is required")
336+
334337
if not isinstance(model, str) and ref_model is model:
335338
raise ValueError(
336339
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "

trl/experimental/online_dpo/online_dpo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,9 @@ def __init__(
199199
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
200200
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
201201
) -> None:
202+
if train_dataset is None:
203+
raise ValueError("`train_dataset` is required")
204+
202205
if ref_model is model:
203206
raise ValueError(
204207
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "

trl/experimental/orpo/orpo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,9 @@ def __init__(
154154
peft_config: dict | None = None,
155155
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
156156
):
157+
if train_dataset is None:
158+
raise ValueError("`train_dataset` is required")
159+
157160
if args.model_init_kwargs is None:
158161
model_init_kwargs = {}
159162
elif not isinstance(model, str):

trl/experimental/ppo/ppo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -371,6 +371,9 @@ def __init__(
371371
callbacks: list[TrainerCallback] | None = None,
372372
peft_config: "PeftConfig | None" = None,
373373
) -> None:
374+
if train_dataset is None:
375+
raise ValueError("`train_dataset` is required")
376+
374377
if ref_model is model:
375378
raise ValueError(
376379
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "

trl/experimental/prm/prm_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,9 @@ def __init__(
169169
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
170170
peft_config: dict | None = None,
171171
):
172+
if train_dataset is None:
173+
raise ValueError("`train_dataset` is required")
174+
172175
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
173176
model = prepare_peft_model(model, peft_config, args)
174177

trl/trainer/dpo_trainer.py

Lines changed: 30 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -336,12 +336,23 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
336336
rejected_type_ids = processed_rejecteds["token_type_ids"]
337337
completion_token_type_ids = torch.cat(tuple(pad([chosen_type_ids, rejected_type_ids], padding_value=0)))
338338
token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
339+
if "mm_token_type_ids" in processed_prompts: # special case for Qwen2.5-VL
340+
prompt_mm_token_type_ids = processed_prompts["mm_token_type_ids"]
341+
mm_token_type_ids = torch.cat((prompt_mm_token_type_ids, torch.zeros_like(completion_ids)), dim=1)
339342

340343
# Flush left to reduce padding
341-
if "token_type_ids" in processed_prompts:
344+
if "token_type_ids" in processed_prompts and "mm_token_type_ids" in processed_prompts:
345+
attention_mask, input_ids, completion_mask, token_type_ids, mm_token_type_ids = flush_left(
346+
attention_mask, input_ids, completion_mask, token_type_ids, mm_token_type_ids
347+
)
348+
elif "token_type_ids" in processed_prompts:
342349
attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
343350
attention_mask, input_ids, completion_mask, token_type_ids
344351
)
352+
elif "mm_token_type_ids" in processed_prompts:
353+
attention_mask, input_ids, completion_mask, mm_token_type_ids = flush_left(
354+
attention_mask, input_ids, completion_mask, mm_token_type_ids
355+
)
345356
else:
346357
attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
347358

@@ -352,6 +363,8 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
352363
output["completion_mask"] = completion_mask
353364
if "token_type_ids" in processed_prompts:
354365
output["token_type_ids"] = token_type_ids
366+
if "mm_token_type_ids" in processed_prompts:
367+
output["mm_token_type_ids"] = mm_token_type_ids
355368
return output
356369

357370

@@ -992,7 +1005,14 @@ def compute_ref_log_probs(self, inputs):
9921005
shift_completion_mask = completion_mask[..., 1:].contiguous()
9931006

9941007
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
995-
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
1008+
for key in (
1009+
"pixel_values",
1010+
"pixel_attention_mask",
1011+
"image_grid_thw",
1012+
"image_sizes",
1013+
"token_type_ids",
1014+
"mm_token_type_ids",
1015+
):
9961016
if key in inputs:
9971017
model_kwargs[key] = inputs[key]
9981018

@@ -1113,7 +1133,14 @@ def _compute_loss(self, model, inputs, return_outputs):
11131133
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)
11141134

11151135
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
1116-
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
1136+
for key in (
1137+
"pixel_values",
1138+
"pixel_attention_mask",
1139+
"image_grid_thw",
1140+
"image_sizes",
1141+
"token_type_ids",
1142+
"mm_token_type_ids",
1143+
):
11171144
if key in inputs:
11181145
model_kwargs[key] = inputs[key]
11191146

trl/trainer/grpo_trainer.py

Lines changed: 10 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,9 @@
2525
from collections import defaultdict, deque
2626
from collections.abc import Callable
2727
from contextlib import nullcontext
28-
from functools import partial
2928
from pathlib import Path
3029
from typing import Any, Protocol
3130

32-
import datasets
3331
import numpy as np
3432
import pandas as pd
3533
import torch
@@ -42,7 +40,7 @@
4240
from packaging.version import Version
4341
from torch import nn
4442
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
45-
from torch.utils.data import DataLoader, Sampler
43+
from torch.utils.data import Sampler
4644
from transformers import (
4745
AutoModelForSequenceClassification,
4846
AutoProcessor,
@@ -55,8 +53,7 @@
5553
is_trackio_available,
5654
is_wandb_available,
5755
)
58-
from transformers.trainer_utils import seed_worker
59-
from transformers.utils import is_datasets_available, is_peft_available, is_rich_available
56+
from transformers.utils import is_peft_available, is_rich_available
6057

6158
from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response
6259
from ..data_utils import (
@@ -849,37 +846,15 @@ def _set_signature_columns_if_needed(self):
849846
# `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
850847
# splitting internally.
851848
# Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
852-
# modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line
853-
# apart from the super method, ensuring easier maintenance in the future.
849+
# modification.
854850
def get_train_dataloader(self):
855-
if self.train_dataset is None:
856-
raise ValueError("Trainer: training requires a train_dataset.")
857-
858-
train_dataset = self.train_dataset
859-
data_collator = self.data_collator
860-
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
861-
train_dataset = self._remove_unused_columns(train_dataset, description="training")
862-
else:
863-
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
864-
865-
dataloader_params = {
866-
"batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change
867-
"collate_fn": data_collator,
868-
"num_workers": self.args.dataloader_num_workers,
869-
"pin_memory": self.args.dataloader_pin_memory,
870-
"persistent_workers": self.args.dataloader_persistent_workers,
871-
}
872-
873-
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
874-
dataloader_params["sampler"] = self._get_train_sampler()
875-
dataloader_params["drop_last"] = self.args.dataloader_drop_last
876-
dataloader_params["worker_init_fn"] = partial(
877-
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
878-
)
879-
880-
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
881-
882-
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
851+
return self._get_dataloader(
852+
dataset=self.train_dataset,
853+
description="Training",
854+
batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change
855+
sampler_fn=self._get_train_sampler,
856+
is_training=True,
857+
)
883858

884859
def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
885860
# Returns a sampler that

0 commit comments

Comments (0)