Skip to content

Commit db0d955

Browse files
Fix CI FutureWarning: rpo_alpha is deprecated (#5011)
1 parent fa06506 commit db0d955

File tree

1 file changed

+0
-41
lines changed

1 file changed

+0
-41
lines changed

tests/test_dpo_trainer.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -349,45 +349,6 @@ def test_wrong_loss_weights_length(self):
349349
loss_weights=[1.0, 0.5, 0.1], # Wrong length
350350
)
351351

352-
@pytest.mark.parametrize("rpo_alpha", [None, 0.5])
353-
def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
354-
training_args = DPOConfig(
355-
output_dir=self.tmp_dir,
356-
per_device_train_batch_size=2,
357-
max_steps=3,
358-
remove_unused_columns=False,
359-
gradient_accumulation_steps=4,
360-
learning_rate=9e-1,
361-
eval_strategy="steps",
362-
beta=0.1,
363-
precompute_ref_log_probs=True,
364-
rpo_alpha=rpo_alpha,
365-
report_to="none",
366-
)
367-
368-
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
369-
370-
trainer = DPOTrainer(
371-
model=self.model,
372-
ref_model=None,
373-
args=training_args,
374-
processing_class=self.tokenizer,
375-
train_dataset=dummy_dataset["train"],
376-
eval_dataset=dummy_dataset["test"],
377-
)
378-
379-
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
380-
381-
trainer.train()
382-
383-
assert trainer.state.log_history[-1]["train_loss"] is not None
384-
385-
# Check that the parameters have changed
386-
for n, param in previous_trainable_params.items():
387-
new_param = trainer.model.get_parameter(n)
388-
if param.sum() != 0: # ignore 0 biases
389-
assert not torch.equal(param, new_param)
390-
391352
def test_dpo_trainer_with_ref_model_is_model(self):
392353
training_args = DPOConfig(
393354
output_dir=self.tmp_dir,
@@ -914,7 +875,6 @@ def test_dpo_trainer_use_logits_to_keep(self):
914875
eval_strategy="steps",
915876
beta=0.1,
916877
use_logits_to_keep=True,
917-
rpo_alpha=0.5,
918878
report_to="none",
919879
)
920880

@@ -960,7 +920,6 @@ def test_dpo_trainer_use_logits_to_keep(self):
960920
output = trainer.concatenated_forward(model, batch)
961921
output2 = trainer2.concatenated_forward(model, batch)
962922

963-
np.testing.assert_allclose(output["nll_loss"].item(), output2["nll_loss"].item(), atol=1e-5)
964923
np.testing.assert_allclose(
965924
output["mean_chosen_logits"].item(), output2["mean_chosen_logits"].item(), atol=1e-5
966925
)

0 commit comments

Comments (0)