@@ -648,7 +648,7 @@ def setUp(self):
648648 if multiprocessing .get_start_method (allow_none = True ) != "spawn" :
649649 multiprocessing .set_start_method ("spawn" , force = True )
650650 self .config = get_template_config ()
651- self .config .buffer .total_epochs = 2
651+ self .config .buffer .total_epochs = 1
652652 self .config .buffer .batch_size = 4
653653 self .config .model .model_path = get_model_path ()
654654 self .config .explorer .rollout_model .engine_type = "vllm_async"
@@ -657,22 +657,15 @@ def setUp(self):
657657 self .config .name = f"trainer-{ datetime .now ().strftime ('%Y%m%d%H%M%S' )} "
658658 self .config .monitor .monitor_type = "tensorboard"
659659 self .config .checkpoint_root_dir = get_checkpoint_path ()
660- self .config .synchronizer .sync_interval = 2
661- self .config .synchronizer .sync_method = SyncMethod .NCCL
660+ self .config .synchronizer .sync_interval = 1
661+ self .config .synchronizer .sync_method = SyncMethod .CHECKPOINT
662662 self .config .explorer .eval_interval = 4
663-
664- def test_trainer (self ):
665- """Test the checkpoint saving."""
666663 self .config .buffer .explorer_input .taskset = get_unittest_dataset_config ("countdown" )
667- self .config .buffer .explorer_input .eval_tasksets .append (
668- get_unittest_dataset_config ("countdown" , "test" )
669- )
670- self .config .buffer .explorer_input .eval_tasksets .append (
671- get_unittest_dataset_config ("copy_countdown" , "test" )
672- )
673664 self .config .trainer .save_interval = 4
674- self .config .synchronizer .sync_method = SyncMethod .CHECKPOINT
675665 self .config .check_and_update ()
666+
667+ def test_trainer (self ):
668+ """Test the checkpoint saving."""
676669 _trainer_config = self .config .trainer .trainer_config
677670 if self .strategy == "megatron" :
678671 _trainer_config .actor_rollout_ref .actor .strategy = "megatron"
0 commit comments