@@ -60,7 +60,7 @@ def test_trainer(self):
6060 self .config .buffer .explorer_input .eval_tasksets .append (
6161 get_unittest_dataset_config ("copy_countdown" , "test" )
6262 )
63- self .config .trainer .save_interval = 6
63+ self .config .trainer .save_interval = 4
6464 self .config .check_and_update ()
6565 self .config .trainer .trainer_config .trainer .max_actor_ckpt_to_keep = 2
6666 self .config .trainer .trainer_config .trainer .max_critic_ckpt_to_keep = 2
@@ -84,17 +84,17 @@ def test_trainer(self):
8484 self .assertEqual (parser .metric_max_step (response_metrics [0 ]), 8 )
8585 ray .shutdown (_exiting_interpreter = True )
8686 # check checkpoint
87- checkpoint_step_6 , _ = get_checkpoint_dir_with_step_num (
87+ checkpoint_step_4 , _ = get_checkpoint_dir_with_step_num (
8888 checkpoint_root_path = self .config .checkpoint_job_dir ,
8989 trainer_type = self .config .trainer .trainer_type ,
90- step_num = 6 ,
90+ step_num = 4 ,
9191 )
9292 # check save lastest checkpoint
9393 checkpoint_step_8 , step_num = get_checkpoint_dir_with_step_num (
9494 checkpoint_root_path = self .config .checkpoint_job_dir ,
9595 trainer_type = self .config .trainer .trainer_type ,
9696 )
97- self .assertTrue (len (os .listdir (os .path .join (checkpoint_step_6 , "actor" ))) > 0 )
97+ self .assertTrue (len (os .listdir (os .path .join (checkpoint_step_4 , "actor" ))) > 0 )
9898 self .assertTrue (len (os .listdir (os .path .join (checkpoint_step_8 , "actor" ))) > 0 )
9999 self .assertEqual (step_num , 8 )
100100 # TODO: Reinit will fail when using v1 engine, find a way to fix it
0 commit comments