diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 27faa4c44a..177e0e1a81 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -100,6 +100,7 @@ def __init__(
             self.action_mask_method = tokenize_and_mask_messages_default
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
+        self.state_dict_meta = None
         self.ckp_version = 0  # TODO: resume the value from the checkpoint
         self.api_server_host = None
         self.api_server_port = None
@@ -264,9 +265,11 @@ async def _collective_rpc(
             method, timeout, args, kwargs
         )
 
-    async def sync_model(self, update_weight_args_list) -> bool:
+    async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
-        for args in update_weight_args_list:
+        if self.state_dict_meta is None:
+            self.state_dict_meta = update_weight_args_list
+        for args in self.state_dict_meta:
             await self._collective_rpc("update_weight", args=args)
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
@@ -282,7 +285,9 @@ async def init_process_group(
         backend: str = "nccl",
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
+        state_dict_meta: dict = None,
     ):
+        self.state_dict_meta = state_dict_meta
         return await self._collective_rpc(
             "init_process_group",
             args=(
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 9459cd7511..c999a61bfa 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -8,7 +8,7 @@
 import os
 import re
 import threading
-from typing import List
+from typing import List, Optional, Tuple
 
 import torch
 import vllm
@@ -85,6 +85,7 @@ def __init__(self, config: InferenceModelConfig):
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.lock = threading.Lock()
+        self.state_dict_meta = None
         self.ckp_version = 0  # TODO: resume the value from the checkpoint
 
     def init_process_group(
@@ -97,7 +98,9 @@ def init_process_group(
         backend: str = "nccl",
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
+        state_dict_meta: dict = None,
     ):
+        self.state_dict_meta = state_dict_meta
         return self.llm.collective_rpc(
             "init_process_group",
             args=(
@@ -274,10 +277,12 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
     def has_api_server(self) -> bool:
         return False
 
-    def sync_model(self, update_weight_args_list) -> bool:
+    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
+        if self.state_dict_meta is None:
+            self.state_dict_meta = update_weight_args_list
         with self.lock:
-            for args in update_weight_args_list:
+            for args in self.state_dict_meta:
                 self.llm.collective_rpc("update_weight", args=args)
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 87657c9b42..26ee0b53c2 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -96,6 +96,7 @@ def setup_weight_sync_group(
                 group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
                 timeout=self.config.synchronizer.sync_timeout,
                 update_with_checkpoint=self.use_checkpoint_weights_update,
+                state_dict_meta=state_dict_meta,
             )
             for i, model in enumerate(self.models)
         ]
@@ -119,9 +120,13 @@ def _init_runner_pool(self) -> RunnerPool:
     def _update_model_weight(self, state_dict: dict) -> None:
         # TODO: update model weight
         self.state_dict = state_dict
-        update_weight_args_list = []
-        for name, param in state_dict.items():
-            update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+        if self.state_dict_meta is None:
+            update_weight_args_list = []
+            for name, param in state_dict.items():
+                update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+            self.state_dict_meta = update_weight_args_list
+        else:
+            update_weight_args_list = None
         ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models])
         self.state_dict.clear()
 
@@ -142,7 +147,8 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
             self.logger.error(f"Error when loading state_dict: {e}")
 
     def _nccl_weights_update(self):
-        ray.get([model.sync_model.remote(self.state_dict_meta) for model in self.models])
+        assert self.state_dict_meta is not None
+        ray.get([model.sync_model.remote() for model in self.models])
 
     def prepare(self) -> None:
         """Preparation before running."""
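
For context, a minimal sketch of the metadata-caching handshake this diff introduces: the first `sync_model` call ships the `(name, dtype, shape)` tuples once, and every later call passes nothing and reuses the cached `state_dict_meta`. The `ModelStub` class and the toy driver below are hypothetical illustrations, not part of the patch; the real models issue an `update_weight` collective RPC per tensor.

```python
from typing import List, Optional, Tuple

class ModelStub:
    """Stands in for the sync_model caching behavior added in this diff."""

    def __init__(self) -> None:
        self.state_dict_meta: Optional[List[Tuple]] = None

    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> None:
        # First call: cache the (name, dtype, shape) tuples sent by the explorer.
        # Later calls: update_weight_args_list is None, so the cache is reused.
        if self.state_dict_meta is None:
            self.state_dict_meta = update_weight_args_list
        for name, dtype, shape in self.state_dict_meta:
            # In the real models this triggers an "update_weight" collective RPC
            # that receives the tensor named `name` with the given dtype/shape.
            print(f"updating {name}: {dtype} {shape}")

model = ModelStub()
meta = [("layer.weight", "torch.float32", (4, 4))]
model.sync_model(meta)  # first sync: metadata travels once
model.sync_model()      # subsequent syncs: no arguments needed
```

The design choice mirrors the explorer side: since parameter names, dtypes, and shapes are identical on every sync, rebuilding and re-sending the list per step is redundant, so both ends cache it after the first exchange (or receive it up front via `init_process_group`'s new `state_dict_meta` argument).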