Skip to content

Commit 9ff79a6

Browse files
authored
🔮 Fix unused precomputed ref log probs in DPO (huggingface#2431)
1 parent 9001a86 commit 9ff79a6

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

trl/trainer/dpo_trainer.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
130130
pixel_values = [torch.tensor(example["pixel_values"]) for example in examples]
131131
if "pixel_attention_mask" in examples[0]:
132132
pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples]
133+
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
134+
ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
135+
ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])
133136

134137
# Pad
135138
output = {}
@@ -145,6 +148,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
145148
output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
146149
if "image_sizes" in examples[0]:
147150
output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
151+
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
152+
output["ref_chosen_logps"] = ref_chosen_logps
153+
output["ref_rejected_logps"] = ref_rejected_logps
148154

149155
return output
150156

@@ -162,7 +168,7 @@ class DPOTrainer(Trainer):
162168
args (`DPOConfig`):
163169
The DPO config arguments to use for training.
164170
data_collator (`transformers.DataCollator`):
165-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
171+
The data collator to use for training. If None is specified, the default data collator (`PreferenceCollator`) will be used
166172
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
167173
train_dataset (`datasets.Dataset`):
168174
The dataset to use for training.
@@ -672,9 +678,16 @@ def _set_signature_columns_if_needed(self):
672678
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
673679
# By default, this method sets `self._signature_columns` to the model's expected inputs.
674680
# In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
675-
# Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
681+
# Instead, we set them to the columns expected by `PreferenceCollator`, hence the override.
676682
if self._signature_columns is None:
677-
self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes"]
683+
self._signature_columns = [
684+
"prompt_input_ids",
685+
"chosen_input_ids",
686+
"rejected_input_ids",
687+
"image_sizes",
688+
"ref_chosen_logps",
689+
"ref_rejected_logps",
690+
]
678691

679692
def get_train_dataloader(self) -> DataLoader:
680693
"""

0 commit comments

Comments
 (0)