Skip to content

Commit c330c82

Browse files
committed
fix in unittest
1 parent b4224d3 commit c330c82

File tree

1 file changed

+6
-13
lines changed

1 file changed

+6
-13
lines changed

tests/trainer/trainer_test.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def setUp(self):
648648
if multiprocessing.get_start_method(allow_none=True) != "spawn":
649649
multiprocessing.set_start_method("spawn", force=True)
650650
self.config = get_template_config()
651-
self.config.buffer.total_epochs = 2
651+
self.config.buffer.total_epochs = 1
652652
self.config.buffer.batch_size = 4
653653
self.config.model.model_path = get_model_path()
654654
self.config.explorer.rollout_model.engine_type = "vllm_async"
@@ -657,22 +657,15 @@ def setUp(self):
657657
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
658658
self.config.monitor.monitor_type = "tensorboard"
659659
self.config.checkpoint_root_dir = get_checkpoint_path()
660-
self.config.synchronizer.sync_interval = 2
661-
self.config.synchronizer.sync_method = SyncMethod.NCCL
660+
self.config.synchronizer.sync_interval = 1
661+
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
662662
self.config.explorer.eval_interval = 4
663-
664-
def test_trainer(self):
665-
"""Test the checkpoint saving."""
666663
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
667-
self.config.buffer.explorer_input.eval_tasksets.append(
668-
get_unittest_dataset_config("countdown", "test")
669-
)
670-
self.config.buffer.explorer_input.eval_tasksets.append(
671-
get_unittest_dataset_config("copy_countdown", "test")
672-
)
673664
self.config.trainer.save_interval = 4
674-
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
675665
self.config.check_and_update()
666+
667+
def test_trainer(self):
668+
"""Test the checkpoint saving."""
676669
_trainer_config = self.config.trainer.trainer_config
677670
if self.strategy == "megatron":
678671
_trainer_config.actor_rollout_ref.actor.strategy = "megatron"

0 commit comments

Comments
 (0)