
Commit 5424923

fix tests
1 parent: 4edd89c

5 files changed, +37 -8 lines

tests/trainer/trainer_test.py

Lines changed: 23 additions & 7 deletions
@@ -749,7 +749,7 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)
         self.config = get_template_config()
-        self.config.buffer.total_epochs = 1
+        self.config.buffer.total_steps = 6
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
@@ -762,9 +762,10 @@ def setUp(self):
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 2
         self.config.trainer.save_hf_checkpoint = "last"
         self.config.trainer.trainer_strategy = self.strategy
+        self.config.trainer.max_checkpoints_to_keep = 2
         self.config.check_and_update()
         self.process_list = []

@@ -775,8 +776,6 @@ def test_trainer(self): # noqa: C901
         _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
         _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
         _trainer_config.critic.megatron.tensor_model_parallel_size = 2
-        _trainer_config.trainer.max_actor_ckpt_to_keep = 2
-        _trainer_config.trainer.max_critic_ckpt_to_keep = 2

         stop_event = multiprocessing.Event()
         trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event))
@@ -887,10 +886,27 @@ def test_trainer(self): # noqa: C901
         if not stop_event.is_set():
             self.fail("Training process failed to stop.")
         # check only full checkpoint dirs are kept
-        for sync_step in [0, 1, 2, 3]:
+        for sync_step in [1, 3, 5]:
             state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
-            self.assertFalse(os.path.exists(state_dict_dir))
-        self.assertTrue(os.path.exists(os.path.join(default_local_dir, "global_step_4")))
+            self.assertFalse(
+                os.path.exists(state_dict_dir),
+                f"Found unexpected state dict dir at step {sync_step}",
+            )
+        for checkpoint_step in [4, 6]:
+            checkpoint_dir = os.path.join(default_local_dir, f"global_step_{checkpoint_step}")
+            self.assertTrue(
+                os.path.exists(checkpoint_dir),
+                f"Missing expected checkpoint dir at step {checkpoint_step}",
+            )
+            actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+            self.assertTrue(os.path.exists(actor_checkpoint_dir))
+        # check that step 2 keeps its directory but no actor/critic checkpoint
+        checkpoint_dir = os.path.join(default_local_dir, "global_step_2")
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+        self.assertFalse(os.path.exists(actor_checkpoint_dir))
+        critic_checkpoint_dir = os.path.join(checkpoint_dir, "critic")
+        self.assertFalse(os.path.exists(critic_checkpoint_dir))
         trainer_process.join(timeout=10)
         self.assertIn("model.safetensors", huggingface_dir_files)
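
The new assertions pin down the retention arithmetic implied by the updated config: with total_steps = 6 and save_interval = 2, full checkpoints are written at steps 2, 4, and 6, and max_checkpoints_to_keep = 2 prunes all but the last two, so step 2 keeps its directory but loses its actor/critic subdirectories. A minimal sketch of that rule, with an illustrative helper name that is not part of the repo:

    # Which global_step_N directories should still hold a full checkpoint?
    def expected_full_checkpoints(total_steps, save_interval, max_to_keep):
        saved = [s for s in range(1, total_steps + 1) if s % save_interval == 0]
        return saved[-max_to_keep:]  # only the newest max_to_keep survive pruning

    assert expected_full_checkpoints(6, 2, 2) == [4, 6]  # step 2 is pruned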

trinity/manager/synchronizer.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ async def _remove_previous_state_dict(self, previous_model_version: int) -> None
         self.logger.info(
             f"Removing previous checkpoint for sync at step {previous_model_version}."
         )
-        shutil.rmtree(previous_state_dict_dir)
+        shutil.rmtree(previous_state_dict_dir, ignore_errors=True)

     async def _find_tinker_latest_state_dict(self) -> None:
         default_local_dir = self.config.checkpoint_job_dir
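
Passing ignore_errors=True makes the removal idempotent: if the checkpoint manager (or an earlier sync) already deleted the directory, rmtree becomes a no-op instead of raising. A standalone illustration using a throwaway temp directory:

    import shutil
    import tempfile

    tmp = tempfile.mkdtemp()
    shutil.rmtree(tmp)                      # first removal succeeds
    shutil.rmtree(tmp, ignore_errors=True)  # repeated removal is silently skipped
    try:
        shutil.rmtree(tmp)                  # without the flag, a missing dir raises
    except FileNotFoundError:
        pass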

trinity/trainer/verl/fsdp_checkpoint_manager.py

Lines changed: 6 additions & 0 deletions
@@ -48,6 +48,7 @@
 
 from trinity.manager.synchronizer import Synchronizer
 from trinity.trainer.verl_trainer import CheckpointMonitor
+from trinity.utils.log import get_logger
 
 
 class FSDPCheckpointManager(OldFSDPCheckpointManager):
@@ -62,6 +63,7 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager):
 
     def __init__(self, *args, ray_namespace: str = "", **kwargs):
         super().__init__(*args, **kwargs)
+        self.logger = get_logger()
         self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
         self.checkpoint_monitor = CheckpointMonitor.get_actor(
             namespace=ray_namespace,
@@ -439,6 +441,10 @@ def save_checkpoint(
             and local_path != self.previous_saved_paths[-1]  # type: ignore
         ):  # last step may save twice
             keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1  # type: ignore
+            self.logger.info(
+                "Checkpoint manager is removing previous checkpoints at "
+                + str(self.previous_saved_paths[:keep_start])  # type: ignore
+            )
             self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])  # type: ignore
             self.previous_saved_paths = self.previous_saved_paths[keep_start:]  # type: ignore
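
The logged slice is exactly what remove_previous_save_local_path deletes. A worked example of the keep_start arithmetic, using values that mirror the test configuration rather than anything read from the repo (the same logic appears in the Megatron manager below):

    # While the step-6 checkpoint is being written, two paths were saved earlier.
    previous_saved_paths = ["global_step_2", "global_step_4"]
    max_ckpt_to_keep = 2
    keep_start = len(previous_saved_paths) - max_ckpt_to_keep + 1  # 2 - 2 + 1 = 1
    to_remove = previous_saved_paths[:keep_start]                  # ["global_step_2"]
    previous_saved_paths = previous_saved_paths[keep_start:]       # ["global_step_4"]
    # global_step_4 plus the step-6 checkpoint being saved = 2 checkpoints kept,
    # which is why the test expects global_step_2 to lose its actor/critic dirs.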

trinity/trainer/verl/megatron_checkpoint_manager.py

Lines changed: 6 additions & 0 deletions
@@ -40,6 +40,7 @@
 
 from trinity.manager.synchronizer import Synchronizer
 from trinity.trainer.verl_trainer import CheckpointMonitor
+from trinity.utils.log import get_logger
 
 
 class MegatronCheckpointManager(OldMegatronCheckpointManager):
@@ -59,6 +60,7 @@ def __init__(
             *args,
             **kwargs,
         )
+        self.logger = get_logger()
         self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
         self.checkpoint_monitor = CheckpointMonitor.get_actor(
             namespace=ray_namespace,
@@ -340,6 +342,10 @@ def save_checkpoint(
             and local_path != self.previous_saved_paths[-1]  # type: ignore
         ):  # last step may save twice
             keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1  # type: ignore
+            self.logger.info(
+                "Checkpoint manager is removing previous checkpoints at "
+                + str(self.previous_saved_paths[:keep_start])  # type: ignore
+            )
             self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])  # type: ignore
             self.previous_saved_paths = self.previous_saved_paths[keep_start:]  # type: ignore

trinity/trainer/verl_trainer.py

Lines changed: 1 addition & 0 deletions
@@ -498,6 +498,7 @@ def _save_checkpoint(self, save_as_hf: bool = False):
         # make sure this flag is created before notifying the synchronizer
         # to avoid the synchronizer recognizing it as a state_dict-only checkpoint
         # TODO: use a better way to indicate full checkpoint
+        os.makedirs(local_global_step_folder, exist_ok=True)
         flag_path = os.path.join(local_global_step_folder, ".full_checkpoint")
         with open(flag_path, "w") as f:
             f.write("")
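
Creating the directory with exist_ok=True before writing the marker means the open() below cannot fail with FileNotFoundError, even if no shard has created the step folder yet. A self-contained sketch of the flag-file pattern (the folder path here is hypothetical):

    import os

    local_global_step_folder = "/tmp/checkpoints/global_step_4"  # assumed path
    os.makedirs(local_global_step_folder, exist_ok=True)  # no-op if it already exists
    flag_path = os.path.join(local_global_step_folder, ".full_checkpoint")
    with open(flag_path, "w") as f:  # empty marker file
        f.write("")
    # the synchronizer treats directories carrying this flag as full checkpoints
    assert os.path.exists(flag_path)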
