@@ -349,45 +349,6 @@ def test_wrong_loss_weights_length(self):
349 349             loss_weights=[1.0, 0.5, 0.1],  # Wrong length
350 350         )
351351
352     -    @pytest.mark.parametrize("rpo_alpha", [None, 0.5])
353     -    def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha):
354     -        training_args = DPOConfig(
355     -            output_dir=self.tmp_dir,
356     -            per_device_train_batch_size=2,
357     -            max_steps=3,
358     -            remove_unused_columns=False,
359     -            gradient_accumulation_steps=4,
360     -            learning_rate=9e-1,
361     -            eval_strategy="steps",
362     -            beta=0.1,
363     -            precompute_ref_log_probs=True,
364     -            rpo_alpha=rpo_alpha,
365     -            report_to="none",
366     -        )
367     -
368     -        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
369     -
370     -        trainer = DPOTrainer(
371     -            model=self.model,
372     -            ref_model=None,
373     -            args=training_args,
374     -            processing_class=self.tokenizer,
375     -            train_dataset=dummy_dataset["train"],
376     -            eval_dataset=dummy_dataset["test"],
377     -        )
378     -
379     -        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
380     -
381     -        trainer.train()
382     -
383     -        assert trainer.state.log_history[-1]["train_loss"] is not None
384     -
385     -        # Check that the parameters have changed
386     -        for n, param in previous_trainable_params.items():
387     -            new_param = trainer.model.get_parameter(n)
388     -            if param.sum() != 0:  # ignore 0 biases
389     -                assert not torch.equal(param, new_param)
390     -
390-
391 352     def test_dpo_trainer_with_ref_model_is_model(self):
392 353         training_args = DPOConfig(
393 354             output_dir=self.tmp_dir,
@@ -914,7 +875,6 @@ def test_dpo_trainer_use_logits_to_keep(self):
914 875             eval_strategy="steps",
915 876             beta=0.1,
916 877             use_logits_to_keep=True,
917     -            rpo_alpha=0.5,
918 878             report_to="none",
919 879         )
920880
@@ -960,7 +920,6 @@ def test_dpo_trainer_use_logits_to_keep(self):
960 920         output = trainer.concatenated_forward(model, batch)
961 921         output2 = trainer2.concatenated_forward(model, batch)
962 922
963     -        np.testing.assert_allclose(output["nll_loss"].item(), output2["nll_loss"].item(), atol=1e-5)
964 923         np.testing.assert_allclose(
965 924             output["mean_chosen_logits"].item(), output2["mean_chosen_logits"].item(), atol=1e-5
966 925         )
0 commit comments