@@ -749,7 +749,7 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)
         self.config = get_template_config()
-        self.config.buffer.total_epochs = 1
+        self.config.buffer.total_steps = 6
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
@@ -762,9 +762,10 @@ def setUp(self):
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 2
         self.config.trainer.save_hf_checkpoint = "last"
         self.config.trainer.trainer_strategy = self.strategy
+        self.config.trainer.max_checkpoints_to_keep = 2
         self.config.check_and_update()
         self.process_list = []
 
@@ -775,8 +776,6 @@ def test_trainer(self): # noqa: C901
         _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
         _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
         _trainer_config.critic.megatron.tensor_model_parallel_size = 2
-        _trainer_config.trainer.max_actor_ckpt_to_keep = 2
-        _trainer_config.trainer.max_critic_ckpt_to_keep = 2
 
         stop_event = multiprocessing.Event()
         trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event))
@@ -887,10 +886,27 @@ def test_trainer(self): # noqa: C901
         if not stop_event.is_set():
             self.fail("Training process failed to stop.")
         # check only full checkpoint dirs are kept
-        for sync_step in [0, 1, 2, 3]:
+        for sync_step in [1, 3, 5]:
             state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
-            self.assertFalse(os.path.exists(state_dict_dir))
-        self.assertTrue(os.path.exists(os.path.join(default_local_dir, "global_step_4")))
+            self.assertFalse(
+                os.path.exists(state_dict_dir),
+                f"Found unexpected state dict dir at step {sync_step}",
+            )
+        for checkpoint_step in [4, 6]:
+            checkpoint_dir = os.path.join(default_local_dir, f"global_step_{checkpoint_step}")
+            self.assertTrue(
+                os.path.exists(checkpoint_dir),
+                f"Missing expected checkpoint dir at step {checkpoint_step}",
+            )
+            actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+            self.assertTrue(os.path.exists(actor_checkpoint_dir))
+        # step 2 should keep its state dict dir but no actor/critic checkpoint
+        checkpoint_dir = os.path.join(default_local_dir, "global_step_2")
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+        self.assertFalse(os.path.exists(actor_checkpoint_dir))
+        critic_checkpoint_dir = os.path.join(checkpoint_dir, "critic")
+        self.assertFalse(os.path.exists(critic_checkpoint_dir))
         trainer_process.join(timeout=10)
         self.assertIn("model.safetensors", huggingface_dir_files)
 