
Commit 7a1c526

Refactor state_dict_meta init (#90)
1 parent: c85d853

File tree

3 files changed: +25 -9 lines


trinity/common/models/vllm_async_model.py

Lines changed: 7 additions & 2 deletions
@@ -100,6 +100,7 @@ def __init__(
             self.action_mask_method = tokenize_and_mask_messages_default
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
+        self.state_dict_meta = None
         self.ckp_version = 0  # TODO: resume the value from the checkpoint
         self.api_server_host = None
         self.api_server_port = None
@@ -264,9 +265,11 @@ async def _collective_rpc(
             method, timeout, args, kwargs
         )
 
-    async def sync_model(self, update_weight_args_list) -> bool:
+    async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
-        for args in update_weight_args_list:
+        if self.state_dict_meta is None:
+            self.state_dict_meta = update_weight_args_list
+        for args in self.state_dict_meta:
             await self._collective_rpc("update_weight", args=args)
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
@@ -282,7 +285,9 @@ async def init_process_group(
         backend: str = "nccl",
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
+        state_dict_meta: dict = None,
     ):
+        self.state_dict_meta = state_dict_meta
         return await self._collective_rpc(
             "init_process_group",
             args=(
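
Note on the change above: sync_model now caches the (name, dtype, shape) tuples it receives on first use, so later calls can omit the argument and replay the cached metadata. Below is a minimal standalone sketch of that caching pattern; the WeightSyncStub class and the printed RPC stand-in are illustrative, not part of the repository.

from typing import List, Optional, Tuple


class WeightSyncStub:
    """Toy stand-in mimicking the caching behaviour added to sync_model (hypothetical class)."""

    def __init__(self) -> None:
        self.state_dict_meta: Optional[List[Tuple]] = None
        self.ckp_version = 0

    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
        # First call: remember the (name, dtype, shape) tuples that were passed in.
        if self.state_dict_meta is None:
            self.state_dict_meta = update_weight_args_list
        # Every call, including later ones with no argument, replays the cached metadata.
        for args in self.state_dict_meta:
            print("update_weight", args)  # stands in for the collective RPC to the vLLM workers
        self.ckp_version += 1
        return True


stub = WeightSyncStub()
stub.sync_model([("layer.weight", "torch.bfloat16", (4096, 4096))])  # first sync supplies metadata
stub.sync_model()  # later syncs reuse the cached metadata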

trinity/common/models/vllm_model.py

Lines changed: 8 additions & 3 deletions
@@ -8,7 +8,7 @@
 import os
 import re
 import threading
-from typing import List
+from typing import List, Optional, Tuple
 
 import torch
 import vllm
@@ -85,6 +85,7 @@ def __init__(self, config: InferenceModelConfig):
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.lock = threading.Lock()
+        self.state_dict_meta = None
         self.ckp_version = 0  # TODO: resume the value from the checkpoint
 
     def init_process_group(
@@ -97,7 +98,9 @@ def init_process_group(
         backend: str = "nccl",
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
+        state_dict_meta: dict = None,
     ):
+        self.state_dict_meta = state_dict_meta
         return self.llm.collective_rpc(
             "init_process_group",
             args=(
@@ -274,10 +277,12 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
     def has_api_server(self) -> bool:
         return False
 
-    def sync_model(self, update_weight_args_list) -> bool:
+    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
+        if self.state_dict_meta is None:
+            self.state_dict_meta = update_weight_args_list
         with self.lock:
-            for args in update_weight_args_list:
+            for args in self.state_dict_meta:
                 self.llm.collective_rpc("update_weight", args=args)
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
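
The synchronous model gets the same treatment, and init_process_group can now seed the metadata up front so the first sync_model call needs no argument at all. A rough sketch of that call order follows; VllmModelStub is a hypothetical, stripped-down stand-in for the real model class, and the tensor name is made up.

from typing import List, Optional, Tuple


class VllmModelStub:
    """Hypothetical stand-in showing how state_dict_meta can be seeded at group-init time."""

    def __init__(self) -> None:
        self.state_dict_meta: Optional[List[Tuple]] = None

    def init_process_group(self, state_dict_meta: Optional[List[Tuple]] = None) -> None:
        # Mirrors the new keyword argument: metadata is stored before any sync happens.
        self.state_dict_meta = state_dict_meta

    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> None:
        if self.state_dict_meta is None:
            self.state_dict_meta = update_weight_args_list
        for name, dtype, shape in self.state_dict_meta:
            print(f"update_weight({name}, {dtype}, {shape})")


model = VllmModelStub()
model.init_process_group(state_dict_meta=[("lm_head.weight", "torch.float16", (32000, 4096))])
model.sync_model()  # no argument needed: metadata already came from init_process_group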

trinity/explorer/explorer.py

Lines changed: 10 additions & 4 deletions
@@ -96,6 +96,7 @@ def setup_weight_sync_group(
                 group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
                 timeout=self.config.synchronizer.sync_timeout,
                 update_with_checkpoint=self.use_checkpoint_weights_update,
+                state_dict_meta=state_dict_meta,
             )
             for i, model in enumerate(self.models)
         ]
@@ -119,9 +120,13 @@ def _init_runner_pool(self) -> RunnerPool:
     def _update_model_weight(self, state_dict: dict) -> None:
         # TODO: update model weight
         self.state_dict = state_dict
-        update_weight_args_list = []
-        for name, param in state_dict.items():
-            update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+        if self.state_dict_meta is None:
+            update_weight_args_list = []
+            for name, param in state_dict.items():
+                update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+            self.state_dict_meta = update_weight_args_list
+        else:
+            update_weight_args_list = None
         ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models])
         self.state_dict.clear()
 
@@ -142,7 +147,8 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
             self.logger.error(f"Error when loading state_dict: {e}")
 
     def _nccl_weights_update(self):
-        ray.get([model.sync_model.remote(self.state_dict_meta) for model in self.models])
+        assert self.state_dict_meta is not None
+        ray.get([model.sync_model.remote() for model in self.models])
 
     def prepare(self) -> None:
         """Preparation before running."""
