
Commit f6f4265

🧑‍🍳 Add precompute batch size argument in DPOTrainer for reference model (huggingface#2426)
* added precompute_batch

* review-fixes

* moving up

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update trl/trainer/dpo_config.py

* Update trl/trainer/dpo_config.py [ci skip]

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 148b592 commit f6f4265

File tree

  tests/test_dpo_trainer.py
  trl/trainer/dpo_config.py
  trl/trainer/dpo_trainer.py

3 files changed (+43 −2 lines)


tests/test_dpo_trainer.py

Lines changed: 34 additions & 0 deletions
@@ -350,6 +350,40 @@ def test_dpo_trainer_with_ref_model_is_model(self):
                 train_dataset=dummy_dataset["train"],
             )
 
+    def test_precompute_ref_batch_size(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                precompute_ref_log_probs=True,
+                precompute_ref_batch_size=4,
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
     @require_peft
     def test_dpo_trainer_without_providing_ref_model_with_lora(self):
         from peft import LoraConfig

trl/trainer/dpo_config.py

Lines changed: 5 additions & 0 deletions
@@ -94,6 +94,10 @@ class DPOConfig(TrainingArguments):
         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
             Whether to precompute reference model log probabilities for training and evaluation datasets. This is
             useful when training without the reference model to reduce the total GPU memory needed.
+        precompute_ref_batch_size (`Optional[int]`, *optional*, defaults to `None`):
+            Batch size to use when precomputing reference model log probabilities. This can be set higher than the
+            training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
+            training and `per_device_eval_batch_size` for evaluation.
         dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
         model_init_kwargs (`Optional[dict[str, Any]]`, *optional*, defaults to `None`):
@@ -173,6 +177,7 @@ class DPOConfig(TrainingArguments):
     disable_dropout: bool = True
     generate_during_eval: bool = False
     precompute_ref_log_probs: bool = False
+    precompute_ref_batch_size: Optional[int] = None
     dataset_num_proc: Optional[int] = None
     model_init_kwargs: Optional[dict[str, Any]] = None
     ref_model_init_kwargs: Optional[dict[str, Any]] = None
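
To see how the new option is driven from the user-facing API, here is a minimal sketch mirroring the new test above; the tiny checkpoint name is a placeholder, and any causal LM plus preference dataset would work the same way:

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"  # placeholder tiny model
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    training_args = DPOConfig(
        output_dir="dpo-output",
        per_device_train_batch_size=2,   # batch size for the DPO updates themselves
        precompute_ref_log_probs=True,   # run the reference model once, up front
        precompute_ref_batch_size=8,     # larger batch for that forward-only pass
        report_to="none",
    )

    dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=dataset["train"],
    )
    trainer.train()  # reference log probs are precomputed in batches of 8 before training starts

Since the precompute pass runs forward-only (no gradients or optimizer state), it typically fits a larger batch than training does, which is the point of the new knob.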

trl/trainer/dpo_trainer.py

Lines changed: 4 additions & 2 deletions
@@ -684,8 +684,9 @@ def get_train_dataloader(self) -> DataLoader:
         """
 
         if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
+            batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size
             dataloader_params = {
-                "batch_size": self.args.per_device_train_batch_size,
+                "batch_size": batch_size,
                 "collate_fn": self.data_collator,
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
@@ -737,8 +738,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
 
         if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+            batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size
             dataloader_params = {
-                "batch_size": self.args.per_device_eval_batch_size,
+                "batch_size": batch_size,
                 "collate_fn": self.data_collator,
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
