diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 735255ecf2..1ba7731efc 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -43,10 +43,10 @@ def run(self) -> List[Experience]:
 
 @ray.remote
 class DummyModel(InferenceModel):
-    def sync_model(self, update_weight_args_list):
+    def sync_model(self, model_version, update_weight_args_list):
         return True
 
-    def get_ckp_version(self):
+    def get_model_version(self):
         return 0
 
     def init_process_group(
@@ -65,10 +65,10 @@ def init_process_group(
 
 @ray.remote
 class DummyAuxiliaryModel(InferenceModel):
-    def sync_model(self, update_weight_args_list):
+    def sync_model(self, model_version, update_weight_args_list):
         return True
 
-    def get_ckp_version(self):
+    def get_model_version(self):
         return 0
 
     def init_process_group(
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 250ea3eb40..3ef01e2ccd 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -1,4 +1,5 @@
 """Tests for trainer."""
+
 import multiprocessing
 import os
 import shutil
@@ -83,14 +84,12 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
-
-        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=4,
         )
-        checkpoint_step_8 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_8, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=8,
@@ -158,13 +157,12 @@ def test_trainer(self):
 
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
-        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4, step_num = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=4,
         )
+        self.assertEqual(step_num, 4)
         self.assertTrue(os.path.exists(checkpoint_step_4))
 
     def tearDown(self):
@@ -374,19 +372,20 @@ def test_fully_async_mode(self):
         explorer2_cache = CacheManager(explorer2_config)
         cache = explorer2_cache.load_explorer()
         self.assertEqual(cache["latest_iteration"], 4)
-        self.assertIsNotNone(
+        # check the latest checkpoint
+        self.assertEqual(
             get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=explorer1_config.checkpoint_job_dir,
                 trainer_type="verl",
-                step_num=8,
-            )
+            )[1],
+            8,
         )
-        self.assertIsNotNone(
+        self.assertEqual(
             get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=explorer2_config.checkpoint_job_dir,
                 trainer_type="verl",
-                step_num=8,
-            )
+            )[1],
+            8,
         )
         ray.shutdown()
 
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index cb15b1ae3d..1e3eb87058 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -49,7 +49,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
         raise NotImplementedError
 
     @abstractmethod
-    def get_ckp_version(self) -> int:
+    def get_model_version(self) -> int:
         """Get the checkpoint version."""
 
     def get_available_address(self) -> Tuple[str, int]:
@@ -99,8 +99,10 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         else:
             return ray.get(self.model.convert_messages_to_experience.remote(messages))
 
-    def get_ckp_version(self) -> int:
-        return ray.get(self.model.get_ckp_version.remote())
+    @property
+    def model_version(self) -> int:
+        """Get the version of the model."""
+        return ray.get(self.model.get_model_version.remote())
 
     def get_openai_client(self) -> openai.OpenAI:
         """Get the openai client.
diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py
index 5cc770e64f..087b190e86 100644
--- a/trinity/common/models/utils.py
+++ b/trinity/common/models/utils.py
@@ -105,16 +105,21 @@ def get_checkpoint_dir_with_step_num(
     checkpoint_root_path: str,
     trainer_type: str = "verl",
     step_num: Optional[int] = None,
-) -> str:
+) -> Tuple[str, int]:
     """Get the checkpoint directory from a root checkpoint directory.
 
     Args:
         checkpoint_root_path (str): The root checkpoint directory.
        trainer_type (str): The trainer type. Only support "verl" for now.
-        step_num (Optional[int], optional): The step number. Defaults to None.
+        step_num (Optional[int], optional): The step number. If specified,
+            load the checkpoint with the specified step number. If None,
+            load the latest checkpoint. Defaults to None.
+
+    Returns:
+        Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
     """
     if trainer_type == "verl":
-        return get_verl_checkpoint_dir(checkpoint_path=checkpoint_root_path, step_num=step_num)
+        return get_verl_checkpoint_info(checkpoint_path=checkpoint_root_path, step_num=step_num)
     else:
         raise NotImplementedError(f"Unsupported trainer type {trainer_type}")
 
@@ -144,8 +149,20 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
         raise ValueError(f"Unsupported placement: {placement}")
 
 
-def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None) -> str:
-    """Get the checkpoint directory from a Verl root checkpoint directory."""
+def get_verl_checkpoint_info(
+    checkpoint_path: str, step_num: Optional[int] = None
+) -> Tuple[str, int]:
+    """Get the checkpoint directory from a Verl root checkpoint directory.
+
+    Args:
+        checkpoint_path (str): The root checkpoint directory.
+        step_num (Optional[int], optional): The step number. If specified,
+            load the checkpoint with the specified step number. If None,
+            load the latest checkpoint. Defaults to None.
+
+    Returns:
+        Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
+    """
     if step_num is None:
         # load latest checkpoint
         iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
@@ -154,12 +171,12 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
                 iteration_file, "r", encoding="utf-8"
             ) as f:  # TODO: this file may be modified simultaneously
                 iteration = f.read().strip()
-            return os.path.join(checkpoint_path, f"global_step_{iteration}")
+            return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration)
         else:
             raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
     else:  # load specific iteration checkpoint
-        return os.path.join(checkpoint_path, f"global_step_{step_num}")
+        return os.path.join(checkpoint_path, f"global_step_{step_num}"), step_num
 
 
 # copy from verl/scripts/model_merger.py
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 0806bc9c7d..f3253bcf4b 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -102,7 +102,7 @@ def __init__(
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.state_dict_meta = None
-        self.ckp_version = 0  # TODO: resume the value from the checkpoint
+        self.model_version = 0  # TODO: resume the value from the checkpoint
         self.api_server_host = None
         self.api_server_port = None
 
@@ -266,13 +266,15 @@ async def _collective_rpc(
             method, timeout, args, kwargs
         )
 
-    async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
+    async def sync_model(
+        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
+    ) -> bool:
         """Sync model weights to vLLM."""
         if update_weight_args_list is not None:
             await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
         await self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
-        self.ckp_version += 1
+        self.model_version = model_version
         return True
 
     async def init_process_group(
@@ -352,8 +354,8 @@ async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()
 
-    def get_ckp_version(self) -> int:
-        return self.ckp_version
+    def get_model_version(self) -> int:
+        return self.model_version
 
     async def sleep(self, level: int = 1) -> None:
         await self.async_llm.sleep(level=level)
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 59211f198a..643a124f72 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -87,7 +87,7 @@ def __init__(self, config: InferenceModelConfig):
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.lock = threading.Lock()
         self.state_dict_meta = None
-        self.ckp_version = 0  # TODO: resume the value from the checkpoint
+        self.model_version = 0  # TODO: resume the value from the checkpoint
 
     def init_process_group(
         self,
@@ -278,14 +278,16 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
     def has_api_server(self) -> bool:
         return False
 
-    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
+    def sync_model(
+        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
+    ) -> bool:
         """Sync model weights to vLLM."""
         if update_weight_args_list is not None:
             self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
         self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
-        self.ckp_version += 1
+        self.model_version = model_version
         return True
 
-    def get_ckp_version(self) -> int:
-        return self.ckp_version
+    def get_model_version(self) -> int:
+        return self.model_version
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 76869154ed..a54a367b04 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -124,7 +124,7 @@ def _init_runner_pool(self) -> RunnerPool:
         self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
         return RunnerPool(self.config, self.models, self.auxiliary_models)
 
-    async def _update_model_weight(self, state_dict: dict) -> None:
+    async def _update_model_weight(self, step_num: int, state_dict: dict) -> None:
         # TODO: update model weight
         self.state_dict = state_dict
         if self.state_dict_meta is None:
@@ -135,14 +135,14 @@ async def _update_model_weight(self, state_dict: dict) -> None:
         else:
             update_weight_args_list = None
         await asyncio.gather(
-            *[model.sync_model.remote(update_weight_args_list) for model in self.models]
+            *[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models]
         )
         self.state_dict.clear()
 
     async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
         # TODO: support more checkpoint types
         try:
-            checkpoint_dir = get_checkpoint_dir_with_step_num(
+            checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=self.config.checkpoint_job_dir,
                 trainer_type=self.config.trainer.trainer_type,
                 step_num=step_num,
@@ -150,14 +150,16 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> No
             if checkpoint_dir == self.old_checkpoint:
                 return
             model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor"))
-            await self._update_model_weight(model_weights)
+            await self._update_model_weight(checkpoint_step_num, model_weights)
             self.old_checkpoint = checkpoint_dir
         except Exception as e:
             self.logger.warning(f"Fail to load checkpoint: {e}")
 
     async def _nccl_weights_update(self):
         assert self.state_dict_meta is not None
-        await asyncio.gather(*[model.sync_model.remote() for model in self.models])
+        await asyncio.gather(
+            *[model.sync_model.remote(self.explore_step_num) for model in self.models]
+        )
 
     async def prepare(self) -> None:
         """Preparation before running."""
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index b63a9ffadf..f4ed4dda50 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -85,7 +85,7 @@ def run_task(self, task: Task) -> Status:
 
                 if not hasattr(exp, "info") or exp.info is None:
                     exp.info = {}
-                exp.info["model_version"] = self.model_wrapper.get_ckp_version()
+                exp.info["model_version"] = self.model_wrapper.model_version
                 if not hasattr(exp, "metrics") or exp.metrics is None:
                     exp.metrics = {}
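
Note for callers migrating to the renamed APIs (a minimal sketch, not part of the patch above; the checkpoint root path and the commented-out handles are hypothetical placeholders):

    from trinity.common.models.utils import get_checkpoint_dir_with_step_num

    # get_checkpoint_dir_with_step_num now returns a (directory, step_num) tuple;
    # omitting step_num resolves the latest checkpoint and reports its step.
    checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
        checkpoint_root_path="/path/to/checkpoints",  # hypothetical path
        trainer_type="verl",
    )

    # sync_model now takes the model version explicitly instead of bumping an
    # internal counter, so each inference engine records the trainer step whose
    # weights it is serving, e.g. (assuming a Ray actor handle named `model`):
    # await model.sync_model.remote(step_num, update_weight_args_list)

    # The wrapper's get_ckp_version() method is replaced by a read-only property
    # (assuming a ModelWrapper instance named `wrapper`):
    # version = wrapper.model_version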