Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions trinity/common/models/vllm_async_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
self.action_mask_method = tokenize_and_mask_messages_default
else:
self.action_mask_method = tokenize_and_mask_messages_hf
self.state_dict_meta = None
self.ckp_version = 0 # TODO: resume the value from the checkpoint
self.api_server_host = None
self.api_server_port = None
Expand Down Expand Up @@ -264,9 +265,11 @@ async def _collective_rpc(
method, timeout, args, kwargs
)

async def sync_model(self, update_weight_args_list) -> bool:
async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
"""Sync model weights to vLLM."""
for args in update_weight_args_list:
if self.state_dict_meta is None:
self.state_dict_meta = update_weight_args_list
for args in self.state_dict_meta:
await self._collective_rpc("update_weight", args=args)
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
Expand All @@ -282,7 +285,9 @@ async def init_process_group(
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
state_dict_meta: dict = None,
):
self.state_dict_meta = state_dict_meta
return await self._collective_rpc(
"init_process_group",
args=(
Expand Down
11 changes: 8 additions & 3 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import re
import threading
from typing import List
from typing import List, Optional, Tuple

import torch
import vllm
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(self, config: InferenceModelConfig):
else:
self.action_mask_method = tokenize_and_mask_messages_hf
self.lock = threading.Lock()
self.state_dict_meta = None
self.ckp_version = 0 # TODO: resume the value from the checkpoint

def init_process_group(
Expand All @@ -97,7 +98,9 @@ def init_process_group(
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
state_dict_meta: dict = None,
):
self.state_dict_meta = state_dict_meta
return self.llm.collective_rpc(
"init_process_group",
args=(
Expand Down Expand Up @@ -274,10 +277,12 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
def has_api_server(self) -> bool:
return False

def sync_model(self, update_weight_args_list) -> bool:
def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
"""Sync model weights to vLLM."""
if self.state_dict_meta is None:
self.state_dict_meta = update_weight_args_list
with self.lock:
for args in update_weight_args_list:
for args in self.state_dict_meta:
self.llm.collective_rpc("update_weight", args=args)
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
Expand Down
14 changes: 10 additions & 4 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def setup_weight_sync_group(
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
timeout=self.config.synchronizer.sync_timeout,
update_with_checkpoint=self.use_checkpoint_weights_update,
state_dict_meta=state_dict_meta,
)
for i, model in enumerate(self.models)
]
Expand All @@ -119,9 +120,13 @@ def _init_runner_pool(self) -> RunnerPool:
def _update_model_weight(self, state_dict: dict) -> None:
# TODO: update model weight
self.state_dict = state_dict
update_weight_args_list = []
for name, param in state_dict.items():
update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
if self.state_dict_meta is None:
update_weight_args_list = []
for name, param in state_dict.items():
update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
self.state_dict_meta = update_weight_args_list
else:
update_weight_args_list = None
ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models])
self.state_dict.clear()

Expand All @@ -142,7 +147,8 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
self.logger.error(f"Error when loading state_dict: {e}")

def _nccl_weights_update(self):
ray.get([model.sync_model.remote(self.state_dict_meta) for model in self.models])
assert self.state_dict_meta is not None
ray.get([model.sync_model.remote() for model in self.models])

def prepare(self) -> None:
"""Preparation before running."""
Expand Down