@@ -456,3 +456,59 @@ def test_fully_async_mode(self, name, use_priority_queue):
456456 def tearDown (self ):
457457 checkpoint_path = get_checkpoint_path ()
458458 shutil .rmtree (os .path .join (checkpoint_path , "unittest" ))
459+
460+
461+ class TestTrainerMIX (BaseTrainerCase ):
462+ def test_trainer (self ):
463+ """Test MIX algorithm."""
464+ # gsm8k has 16 tasks, sft_for_gsm8k has 8 tasks
465+ # total 4 steps, each step: read 4 tasks from gsm8k, 16 tasks from sft_for_gsm8k
466+ self .config .algorithm .algorithm_type = "mix"
467+ self .config .algorithm .repeat_times = 4
468+ self .config .algorithm .sample_strategy = "mix"
469+ self .config .algorithm .sample_strategy_args = {"expert_data_ratio" : 0.5 } # rft=4*4 : sft=16
470+ self .config .algorithm .policy_loss_fn = "mix"
471+ self .config .buffer .batch_size = 4
472+ self .config .buffer .train_batch_size = 32
473+ self .config .buffer .total_epochs = 1
474+ self .config .buffer .explorer_input .taskset = get_unittest_dataset_config ("gsm8k" )
475+ self .config .synchronizer .sync_interval = 1
476+ self .config .trainer .save_interval = 1
477+ self .config .buffer .trainer_input .sft_warmup_dataset = get_unittest_dataset_config (
478+ "sft_for_gsm8k"
479+ )
480+ self .config .buffer .trainer_input .sft_warmup_dataset .total_epochs = 8 # test this works
481+ self .config .check_and_update ()
482+ self .config .buffer .trainer_input .experience_buffer .max_read_timeout = 20
483+ self .config .trainer .trainer_config .trainer .max_actor_ckpt_to_keep = 2
484+ both (self .config )
485+ parser = TensorBoardParser (os .path .join (self .config .monitor .cache_dir , "tensorboard" ))
486+
487+ # test rollout metrics
488+ rollout_metrics = parser .metric_list ("rollout" )
489+ self .assertTrue (len (rollout_metrics ) > 0 )
490+ self .assertEqual (parser .metric_max_step (rollout_metrics [0 ]), 4 )
491+ self .assertEqual (
492+ parser .metric_values ("rollout/experience_count" )[1 ], 16
493+ ) # 16 rft experiences
494+ # test actor metrics
495+ actor_metrics = parser .metric_list ("actor" )
496+ self .assertTrue (len (actor_metrics ) > 0 )
497+ expert_metrics = parser .metric_list ("actor/expert/" )
498+ self .assertEqual (parser .metric_max_step (expert_metrics [0 ]), 4 ) # SFT
499+ usual_metrics = parser .metric_list ("actor/usual/" )
500+ self .assertEqual (parser .metric_max_step (usual_metrics [0 ]), 4 ) # RFT
501+ response_metrics = parser .metric_list ("response_length" )
502+ self .assertTrue (len (response_metrics ) > 0 )
503+ self .assertEqual (parser .metric_min_step (response_metrics [0 ]), 1 )
504+ self .assertEqual (parser .metric_max_step (response_metrics [0 ]), 4 )
505+ # test save checkpoint at last step
506+ checkpoint_dir , step_num = get_checkpoint_dir_with_step_num (
507+ checkpoint_root_path = self .config .checkpoint_job_dir ,
508+ trainer_type = "verl" ,
509+ )
510+ self .assertEqual (step_num , 4 )
511+ self .assertTrue (len (os .listdir (os .path .join (checkpoint_dir , "actor" ))) > 0 )
512+
513+ def tearDown (self ):
514+ shutil .rmtree (self .config .checkpoint_job_dir )
0 commit comments