diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 0c858ae505..19dda755bd 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -475,6 +475,7 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+  max_checkpoints_to_keep: 5
   trainer_config: null
 ```
@@ -499,6 +500,7 @@ trainer:
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
 - `ulysses_sequence_parallel_size`: Sequence parallel size.
+- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep; once the limit is exceeded, the oldest checkpoints are deleted. If not specified, all checkpoints are kept.
 - `trainer_config`: The trainer configuration provided inline.
 
 ---
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index 6d96bff547..d130517920 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -472,6 +472,7 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+  max_checkpoints_to_keep: 5
   trainer_config: null
 ```
@@ -496,6 +497,7 @@ trainer:
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: The maximum token length per GPU during training; effective when `use_dynamic_bsz=true`.
 - `ulysses_sequence_parallel_size`: The degree of sequence parallelism, i.e., the number of GPUs used to split a single sequence.
+- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep; once the limit is exceeded, the oldest checkpoints are deleted. If not specified, all checkpoints are kept.
 - `trainer_config`: The trainer configuration provided inline.
 
 ---
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 3cd4c8f856..136d95c366 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -344,6 +344,10 @@ def test_trainer(self, mock_load):
 
         # sft warmup stage
         sft_config = stage_configs[0]
+        self.assertEqual(
+            sft_config.synchronizer.sync_interval,
+            sft_config.trainer.save_interval,
+        )
         parser = TensorBoardParser(os.path.join(sft_config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
         self.assertEqual(len(rollout_metrics), 0)
@@ -374,11 +378,15 @@ def test_trainer(self, mock_load):
         self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
         # test save checkpoint when sft finishes
+        for i in range(3):
+            self.assertFalse(
+                os.path.exists(os.path.join(sft_config.checkpoint_job_dir, f"global_step_{i}"))
+            )
         self.assertEqual(
             get_checkpoint_dir_with_step_num(
-                checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=2
+                checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=3
             )[1],
-            2,
+            3,
         )
         # test save checkpoint at last step
         checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
@@ -749,7 +757,7 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)
         self.config = get_template_config()
-        self.config.buffer.total_epochs = 1
+        self.config.buffer.total_steps = 6
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"
@@ -762,21 +770,20 @@ def setUp(self):
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 2
         self.config.trainer.save_hf_checkpoint = "last"
         self.config.trainer.trainer_strategy = self.strategy
+        self.config.trainer.max_checkpoints_to_keep = 2
         self.config.check_and_update()
         self.process_list = []
 
-    def test_trainer(self):
+    def test_trainer(self):  # noqa: C901
         """Test the checkpoint saving."""
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
             _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
             _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
             _trainer_config.critic.megatron.tensor_model_parallel_size = 2
-            _trainer_config.trainer.max_actor_ckpt_to_keep = 2
-            _trainer_config.trainer.max_critic_ckpt_to_keep = 2
 
         stop_event = multiprocessing.Event()
         trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event))
@@ -839,6 +846,10 @@ def test_trainer(self):
             # print(f"State dict check at {state_dict_iteration} iteration passed.")  # for debug
 
             if checkpoint_iteration > 0:
+                flag_file = os.path.join(
+                    default_local_dir, f"global_step_{checkpoint_iteration}", ".full_checkpoint"
+                )
+                self.assertTrue(os.path.exists(flag_file))
                 for sub_dir_name in ["critic", "actor"]:
                     iteration_dir = os.path.join(
                         default_local_dir, f"global_step_{checkpoint_iteration}", sub_dir_name
@@ -882,6 +893,28 @@ def test_trainer(self):
             # print(f"Checkpoint check at {checkpoint_iteration} iteration passed.")  # for debug
         if not stop_event.is_set():
             self.fail("Training process failed to stop.")
+        # check that only full checkpoint dirs are kept
+        for sync_step in [1, 3, 5]:
+            state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
+            self.assertFalse(
+                os.path.exists(state_dict_dir),
+                f"Found unexpected state dict dir at step {sync_step}",
+            )
+        for checkpoint_step in [4, 6]:
+            checkpoint_dir = os.path.join(default_local_dir, f"global_step_{checkpoint_step}")
+            self.assertTrue(
+                os.path.exists(checkpoint_dir),
+                f"Missing expected checkpoint dir at step {checkpoint_step}",
+            )
+            actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+            self.assertTrue(os.path.exists(actor_checkpoint_dir))
+        # step 2 was rotated out by max_checkpoints_to_keep: its dir remains but holds no checkpoint
+        checkpoint_dir = os.path.join(default_local_dir, "global_step_2")
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+        self.assertFalse(os.path.exists(actor_checkpoint_dir))
+        critic_checkpoint_dir = os.path.join(checkpoint_dir, "critic")
+        self.assertFalse(os.path.exists(critic_checkpoint_dir))
         trainer_process.join(timeout=10)
 
         self.assertIn("model.safetensors", huggingface_dir_files)
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index ebe5978a07..5206d3c513 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -63,6 +63,27 @@ def default_config(cls) -> Dict:
             "entropy_loss_fn": "none",
         }
 
+    @classmethod
+    def check_config(cls, config: Config) -> None:
+        if config.mode == "train":
+            if (
+                config.buffer.trainer_input.experience_buffer is None
+                or not config.buffer.trainer_input.experience_buffer.path
+            ):
+                raise ValueError(
+                    "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == sft`"
+                )
+        elif config.mode in ["both", "explore"]:
+            raise ValueError(f"SFT does not support `{config.mode}` mode")
+
+        if config.synchronizer.sync_method != SyncMethod.CHECKPOINT:
+            config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+            logger.warning(
+                "SFT only supports checkpoint synchronization; setting `synchronizer.sync_method` to `checkpoint`."
+            )
+
+        config.synchronizer.sync_interval = config.trainer.save_interval
 
 class PPOAlgorithm(AlgorithmType):
     """PPO Algorithm."""
@@ -232,6 +253,7 @@ def check_config(cls, config: Config) -> None:
             logger.warning(
                 "DPO only supports checkpoint synchronization; setting `synchronizer.sync_method` to `checkpoint`."
             )
+        config.synchronizer.sync_interval = config.trainer.save_interval
         if config.algorithm.repeat_times != 2:
             config.algorithm.repeat_times = 2  # Fake repeat times
         if config.algorithm.kl_loss_fn in {"none", None}:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c80d84b61e..d7e8e3a059 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -735,6 +735,7 @@ class TrainerConfig:
     # TODO: extract more train-related params from underlying trainer engine
     save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
+    max_checkpoints_to_keep: Optional[int] = None
     trainer_config: Any = field(default_factory=dict)
     trainer_config_path: str = ""  # deprecated, use `trainer_config` instead
diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py
index 445f1f63ba..5e3b9f1020 100644
--- a/trinity/common/models/utils.py
+++ b/trinity/common/models/utils.py
@@ -199,6 +199,10 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, T
     Args:
         checkpoint_dir (str): The checkpoint directory.
         config (TrainerConfig): The trainer config. Only the "verl" trainer type is supported for now.
+
+    Returns:
+        Union[dict, Tuple[str, str]]: The state dict. If the checkpoint uses
+            megatron dist checkpointing, returns a tuple of (method, checkpoint_dir).
""" if config.trainer_type == "verl": strategy = config.trainer_strategy diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 15220b18dc..e84f2c1746 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -68,6 +68,7 @@ def update_weight(self): if self._weight_update_rank == 0: state_dict, model_version = ray.get(self.synchronizer.get_model_state_dict.remote()) if isinstance(state_dict, tuple): + # currently only megatron return a tuple method, checkpoint_dir = state_dict if method == "megatron": if self._checkpoint_converter is None: diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 5db971245d..96e9d7f524 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -413,6 +413,9 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.trainer.group_name = config.group self.trainer.experiment_name = config.name self.trainer.default_local_dir = config.checkpoint_job_dir + if config.trainer.max_checkpoints_to_keep is not None: + self.trainer.max_actor_ckpt_to_keep = config.trainer.max_checkpoints_to_keep + self.trainer.max_critic_ckpt_to_keep = config.trainer.max_checkpoints_to_keep if not config.continue_from_checkpoint: self.trainer.resume_mode = "disable" else: diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index c0b913812e..9b7cfd414e 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -2,6 +2,7 @@ import asyncio import os +import shutil from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -95,13 +96,14 @@ async def _find_verl_latest_state_dict(self) -> None: ) while True: if os.path.exists(local_latest_state_dict_iteration): + current_model_version = self.model_version try: with open(local_latest_state_dict_iteration, "r") as f: latest_model_version = int(f.read().strip()) except (IOError, ValueError) as e: self.logger.warning(f"Failed to read or parse state dict iteration file: {e}") continue - if latest_model_version > self.model_version: + if latest_model_version > current_model_version: self.logger.info( f"Synchronizer has found a new model state dict at step {latest_model_version}." ) @@ -119,8 +121,22 @@ async def _find_verl_latest_state_dict(self) -> None: f"Synchronizer has loaded model state dict from checkpoint {latest_model_version}." ) await self.set_model_state_dict(model_state_dict, latest_model_version) + # remove the previous checkpoints to save disk space + await self._remove_previous_state_dict(current_model_version) await asyncio.sleep(1) + async def _remove_previous_state_dict(self, previous_model_version: int) -> None: + previous_state_dict_dir = os.path.join( + self.config.checkpoint_job_dir, f"global_step_{previous_model_version}" + ) + if os.path.exists(previous_state_dict_dir): + # check if it's a full checkpoint, only remove checkpoints for sync + if not os.path.exists(os.path.join(previous_state_dict_dir, ".full_checkpoint")): + self.logger.info( + f"Removing previous checkpoint for sync at step {previous_model_version}." 
+ ) + shutil.rmtree(previous_state_dict_dir, ignore_errors=True) + async def _find_tinker_latest_state_dict(self) -> None: default_local_dir = self.config.checkpoint_job_dir local_latest_state_dict_iteration = os.path.join( @@ -320,9 +336,7 @@ async def get_latest_model_version(self) -> int: async with self._ready_condition: return self.model_version - async def ready_to_nccl_sync( - self, module: str, trainer_step: Optional[int] = None - ) -> Union[int, None]: + async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]: """ Prepare for NCCL-based synchronization between modules. @@ -330,7 +344,7 @@ async def ready_to_nccl_sync( Args: module: Either 'trainer' or 'explorer'. - trainer_step: Optional step number from the trainer. + trainer_step: Step number from the trainer. Returns: The model version if both sides are ready; otherwise None. diff --git a/trinity/trainer/tinker_trainer.py b/trinity/trainer/tinker_trainer.py index e355630063..1265348491 100644 --- a/trinity/trainer/tinker_trainer.py +++ b/trinity/trainer/tinker_trainer.py @@ -282,6 +282,15 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa f"global_step_{self.train_step_num}", ) os.makedirs(local_path, exist_ok=True) + + # save a flag to indicate this is a full checkpoint dir + # make sure this flag is created before notifying the synchronizer + # to avoid the synchronizer recognizing it as a state_dict-only checkpoint + # TODO: use a better way to indicate full checkpoint + flag_path = os.path.join(local_path, ".full_checkpoint") + with open(flag_path, "w") as f: + f.write("") + remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt") with open(remote_checkpoint_path, "w") as f: f.write(self.latest_remote_checkpoint_path) diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 4c3a6aba57..38391bd679 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -48,6 +48,7 @@ from trinity.manager.synchronizer import Synchronizer from trinity.trainer.verl_trainer import CheckpointMonitor +from trinity.utils.log import get_logger class FSDPCheckpointManager(OldFSDPCheckpointManager): @@ -62,6 +63,7 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager): def __init__(self, *args, ray_namespace: str = "", **kwargs): super().__init__(*args, **kwargs) + self.logger = get_logger() self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace) self.checkpoint_monitor = CheckpointMonitor.get_actor( namespace=ray_namespace, @@ -439,6 +441,10 @@ def save_checkpoint( and local_path != self.previous_saved_paths[-1] # type: ignore ): # last step may save twice keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore + self.logger.info( + "Checkpoint manager is removing previous checkpoints at " + + str(self.previous_saved_paths[:keep_start]) # type: ignore + ) self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index 4fadcae477..6659975307 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -40,6 +40,7 @@ from trinity.manager.synchronizer import Synchronizer from trinity.trainer.verl_trainer import 
CheckpointMonitor +from trinity.utils.log import get_logger class MegatronCheckpointManager(OldMegatronCheckpointManager): @@ -59,6 +60,7 @@ def __init__( *args, **kwargs, ) + self.logger = get_logger() self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace) self.checkpoint_monitor = CheckpointMonitor.get_actor( namespace=ray_namespace, @@ -340,6 +342,10 @@ def save_checkpoint( and local_path != self.previous_saved_paths[-1] # type: ignore ): # last step may save twice keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore + self.logger.info( + "Checkpoint manager is removing previous checkpoints at " + + str(self.previous_saved_paths[:keep_start]) # type: ignore + ) self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 8583fd3363..e5870e8e3b 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -494,6 +494,15 @@ def _save_checkpoint(self, save_as_hf: bool = False): self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" ) + # save a flag to indicate this is a full checkpoint dir + # make sure this flag is created before notifying the synchronizer + # to avoid the synchronizer recognizing it as a state_dict-only checkpoint + # TODO: use a better way to indicate full checkpoint + os.makedirs(local_global_step_folder, exist_ok=True) + flag_path = os.path.join(local_global_step_folder, ".full_checkpoint") + with open(flag_path, "w") as f: + f.write("") + self.logger.info(f"local_global_step_folder: {local_global_step_folder}") actor_local_path = os.path.join(local_global_step_folder, "actor")
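
Taken together, the changes above implement a small flag-file protocol: each trainer writes an empty `.full_checkpoint` marker into its `global_step_{N}` directory before the Synchronizer is notified, and the Synchronizer's cleanup pass deletes only unflagged (sync-only) step directories, leaving rotation of flagged ones to verl's `max_ckpt_to_keep`. The sketch below distills that convention; the helper names and the temporary directory are illustrative, not the repo's API — only the `global_step_{N}` layout and the `.full_checkpoint` flag come from the diff.

```python
import os
import shutil
import tempfile

FULL_CHECKPOINT_FLAG = ".full_checkpoint"  # flag file name from the diff


def mark_full_checkpoint(step_dir: str) -> None:
    """Flag a global_step_{N} dir as a full checkpoint before the synchronizer sees it."""
    os.makedirs(step_dir, exist_ok=True)
    with open(os.path.join(step_dir, FULL_CHECKPOINT_FLAG), "w") as f:
        f.write("")


def remove_if_sync_only(step_dir: str) -> bool:
    """Delete a step dir only if it holds a sync-time state dict, not a full checkpoint."""
    if os.path.exists(step_dir) and not os.path.exists(
        os.path.join(step_dir, FULL_CHECKPOINT_FLAG)
    ):
        shutil.rmtree(step_dir, ignore_errors=True)
        return True
    return False


if __name__ == "__main__":
    root = tempfile.mkdtemp()
    full_dir = os.path.join(root, "global_step_4")  # written at a save_interval step
    sync_dir = os.path.join(root, "global_step_3")  # written only for weight sync
    mark_full_checkpoint(full_dir)
    os.makedirs(sync_dir, exist_ok=True)
    assert remove_if_sync_only(sync_dir)       # sync-only dir is reclaimed
    assert not remove_if_sync_only(full_dir)   # flagged dir survives cleanup
```

The ordering is the delicate part: because the Synchronizer polls for the latest state dict iteration concurrently, the flag must exist before a step is announced, which is why both `tinker_trainer.py` and `verl_trainer.py` create it first.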