@@ -344,6 +344,10 @@ def test_trainer(self, mock_load):
 
         # sft warmup stage
         sft_config = stage_configs[0]
+        self.assertEqual(
+            sft_config.synchronizer.sync_interval,
+            sft_config.trainer.save_interval,
+        )
         parser = TensorBoardParser(os.path.join(sft_config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
         self.assertEqual(len(rollout_metrics), 0)
@@ -374,11 +378,15 @@ def test_trainer(self, mock_load):
         self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
         # test save checkpoint when sft finish
+        for i in range(3):
+            self.assertFalse(
+                os.path.exists(os.path.join(sft_config.checkpoint_job_dir, f"global_step_{i}"))
+            )
         self.assertEqual(
             get_checkpoint_dir_with_step_num(
-                checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=2
+                checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=3
             )[1],
-            2,
+            3,
         )
         # test save checkpoint at last step
         checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
@@ -749,7 +757,7 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)
         self.config = get_template_config()
-        self.config.buffer.total_epochs = 1
+        self.config.buffer.total_steps = 6
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
@@ -762,21 +770,20 @@ def setUp(self):
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 2
         self.config.trainer.save_hf_checkpoint = "last"
         self.config.trainer.trainer_strategy = self.strategy
+        self.config.trainer.max_checkpoints_to_keep = 2
         self.config.check_and_update()
         self.process_list = []
 
-    def test_trainer(self):
+    def test_trainer(self):  # noqa: C901
         """Test the checkpoint saving."""
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
             _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
             _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
             _trainer_config.critic.megatron.tensor_model_parallel_size = 2
-        _trainer_config.trainer.max_actor_ckpt_to_keep = 2
-        _trainer_config.trainer.max_critic_ckpt_to_keep = 2
 
         stop_event = multiprocessing.Event()
         trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event))
@@ -839,6 +846,10 @@ def test_trainer(self):
             # print(f"State dict check at {state_dict_iteration} iteration passed.")  # for debug
 
             if checkpoint_iteration > 0:
+                flag_file = os.path.join(
+                    default_local_dir, f"global_step_{checkpoint_iteration}", ".full_checkpoint"
+                )
+                self.assertTrue(os.path.exists(flag_file))
                 for sub_dir_name in ["critic", "actor"]:
                     iteration_dir = os.path.join(
                         default_local_dir, f"global_step_{checkpoint_iteration}", sub_dir_name
@@ -882,6 +893,28 @@ def test_trainer(self):
             # print(f"Checkpoint check at {checkpoint_iteration} iteration passed.")  # for debug
         if not stop_event.is_set():
             self.fail("Training process failed to stop.")
+        # check only full checkpoint dirs are kept
+        for sync_step in [1, 3, 5]:
+            state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
+            self.assertFalse(
+                os.path.exists(state_dict_dir),
+                f"Found unexpected state dict dir at step {sync_step}",
+            )
+        for checkpoint_step in [4, 6]:
+            checkpoint_dir = os.path.join(default_local_dir, f"global_step_{checkpoint_step}")
+            self.assertTrue(
+                os.path.exists(checkpoint_dir),
+                f"Missing expected checkpoint dir at step {checkpoint_step}",
+            )
+            actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+            self.assertTrue(os.path.exists(actor_checkpoint_dir))
+        # check step 2 should have no checkpoint
+        checkpoint_dir = os.path.join(default_local_dir, "global_step_2")
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+        self.assertFalse(os.path.exists(actor_checkpoint_dir))
+        critic_checkpoint_dir = os.path.join(checkpoint_dir, "critic")
+        self.assertFalse(os.path.exists(critic_checkpoint_dir))
         trainer_process.join(timeout=10)
         self.assertIn("model.safetensors", huggingface_dir_files)
 