Skip to content

Commit b63ee11

Browse files
committed
fix trainer test
1 parent 292795a commit b63ee11

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/trainer/trainer_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_trainer(self):
6060
self.config.buffer.explorer_input.eval_tasksets.append(
6161
get_unittest_dataset_config("copy_countdown", "test")
6262
)
63-
self.config.trainer.save_interval = 6
63+
self.config.trainer.save_interval = 4
6464
self.config.check_and_update()
6565
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
6666
self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2
@@ -84,17 +84,17 @@ def test_trainer(self):
8484
self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
8585
ray.shutdown(_exiting_interpreter=True)
8686
# check checkpoint
87-
checkpoint_step_6, _ = get_checkpoint_dir_with_step_num(
87+
checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
8888
checkpoint_root_path=self.config.checkpoint_job_dir,
8989
trainer_type=self.config.trainer.trainer_type,
90-
step_num=6,
90+
step_num=4,
9191
)
9292
# check save lastest checkpoint
9393
checkpoint_step_8, step_num = get_checkpoint_dir_with_step_num(
9494
checkpoint_root_path=self.config.checkpoint_job_dir,
9595
trainer_type=self.config.trainer.trainer_type,
9696
)
97-
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_6, "actor"))) > 0)
97+
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0)
9898
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
9999
self.assertEqual(step_num, 8)
100100
# TODO: Reinit will fail when using v1 engine, find a way to fix it

0 commit comments

Comments
 (0)