
Commit 670e49a

Record model version during model weight sync (#102)
1 parent 81330f8 commit 670e49a

8 files changed: +67, -43 lines

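In short, this commit replaces the inference models' self-incremented checkpoint counter (get_ckp_version) with an explicitly supplied model version: the explorer passes the trainer step number into sync_model, and each model reports that number back through get_model_version. Below is a minimal, self-contained sketch of the new convention; DummyInferenceModel is an illustrative stand-in (modeled on the test stubs in this diff), not a class from the repository.

from typing import List, Optional, Tuple


class DummyInferenceModel:
    """Illustrative stand-in showing the convention introduced by this commit."""

    def __init__(self) -> None:
        self.model_version = 0  # same starting value the vLLM wrappers use

    def sync_model(
        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
    ) -> bool:
        # The caller now supplies the version (the trainer step the weights came from)
        # instead of the model bumping an internal counter on every sync.
        self.model_version = model_version
        return True

    def get_model_version(self) -> int:
        return self.model_version


model = DummyInferenceModel()
model.sync_model(8)  # e.g. weights exported at trainer step 8
assert model.get_model_version() == 8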

tests/explorer/runner_pool_test.py

Lines changed: 4 additions & 4 deletions
@@ -43,10 +43,10 @@ def run(self) -> List[Experience]:

 @ray.remote
 class DummyModel(InferenceModel):
-    def sync_model(self, update_weight_args_list):
+    def sync_model(self, model_version, update_weight_args_list):
         return True

-    def get_ckp_version(self):
+    def get_model_version(self):
         return 0

     def init_process_group(
@@ -65,10 +65,10 @@ def init_process_group(

 @ray.remote
 class DummyAuxiliaryModel(InferenceModel):
-    def sync_model(self, update_weight_args_list):
+    def sync_model(self, model_version, update_weight_args_list):
         return True

-    def get_ckp_version(self):
+    def get_model_version(self):
         return 0

     def init_process_group(

tests/trainer/trainer_test.py

Lines changed: 12 additions & 13 deletions
@@ -1,4 +1,5 @@
 """Tests for trainer."""
+
 import multiprocessing
 import os
 import shutil
@@ -83,14 +84,12 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
-
-        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=4,
         )
-        checkpoint_step_8 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_8, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=8,
@@ -156,13 +155,12 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        from trinity.common.models.utils import get_checkpoint_dir_with_step_num

-        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4, step_num = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=4,
         )
+        self.assertEqual(step_num, 4)
         self.assertTrue(os.path.exists(checkpoint_step_4))

     def tearDown(self):
@@ -372,19 +370,20 @@ def test_fully_async_mode(self):
         explorer2_cache = CacheManager(explorer2_config)
         cache = explorer2_cache.load_explorer()
         self.assertEqual(cache["latest_iteration"], 4)
-        self.assertIsNotNone(
+        # check the lastest checkpoint
+        self.assertEqual(
             get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=explorer1_config.checkpoint_job_dir,
                 trainer_type="verl",
-                step_num=8,
-            )
+            )[1],
+            8,
         )
-        self.assertIsNotNone(
+        self.assertEqual(
             get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=explorer2_config.checkpoint_job_dir,
                 trainer_type="verl",
-                step_num=8,
-            )
+            )[1],
+            8,
         )
         ray.shutdown()

trinity/common/models/model.py

Lines changed: 5 additions & 3 deletions
@@ -49,7 +49,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
         raise NotImplementedError

     @abstractmethod
-    def get_ckp_version(self) -> int:
+    def get_model_version(self) -> int:
         """Get the checkpoint version."""

     def get_available_address(self) -> Tuple[str, int]:
@@ -99,8 +99,10 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         else:
             return ray.get(self.model.convert_messages_to_experience.remote(messages))

-    def get_ckp_version(self) -> int:
-        return ray.get(self.model.get_ckp_version.remote())
+    @property
+    def model_version(self) -> int:
+        """Get the version of the model."""
+        return ray.get(self.model.get_model_version.remote())

     def get_openai_client(self) -> openai.OpenAI:
         """Get the openai client.

trinity/common/models/utils.py

Lines changed: 24 additions & 7 deletions
@@ -105,16 +105,21 @@ def get_checkpoint_dir_with_step_num(
     checkpoint_root_path: str,
     trainer_type: str = "verl",
     step_num: Optional[int] = None,
-) -> str:
+) -> Tuple[str, int]:
     """Get the checkpoint directory from a root checkpoint directory.

     Args:
         checkpoint_root_path (str): The root checkpoint directory.
         trainer_type (str): The trainer type. Only support "verl" for now.
-        step_num (Optional[int], optional): The step number. Defaults to None.
+        step_num (Optional[int], optional): The step number. If specified,
+            load the checkpoint with the specified step number. If None,
+            load the latest checkpoint. Defaults to None.
+
+    Returns:
+        Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
     """
     if trainer_type == "verl":
-        return get_verl_checkpoint_dir(checkpoint_path=checkpoint_root_path, step_num=step_num)
+        return get_verl_checkpoint_info(checkpoint_path=checkpoint_root_path, step_num=step_num)
     else:
         raise NotImplementedError(f"Unsupported trainer type {trainer_type}")

@@ -144,8 +149,20 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
         raise ValueError(f"Unsupported placement: {placement}")


-def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None) -> str:
-    """Get the checkpoint directory from a Verl root checkpoint directory."""
+def get_verl_checkpoint_info(
+    checkpoint_path: str, step_num: Optional[int] = None
+) -> Tuple[str, int]:
+    """Get the checkpoint directory from a Verl root checkpoint directory.
+
+    Args:
+        checkpoint_path (str): The root checkpoint directory.
+        step_num (Optional[int], optional): The step number. If specified,
+            load the checkpoint with the specified step number. If None,
+            load the latest checkpoint. Defaults to None.
+
+    Returns:
+        Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
+    """
     if step_num is None:
         # load latest checkpoint
         iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
@@ -154,12 +171,12 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
                 iteration_file, "r", encoding="utf-8"
             ) as f:  # TODO: this file may be modified simultaneously
                 iteration = f.read().strip()
-            return os.path.join(checkpoint_path, f"global_step_{iteration}")
+            return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration)
         else:
             raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
     else:
         # load specific iteration checkpoint
-        return os.path.join(checkpoint_path, f"global_step_{step_num}")
+        return os.path.join(checkpoint_path, f"global_step_{step_num}"), step_num


 # copy from verl/scripts/model_merger.py
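Call sites now unpack a (directory, step) tuple. The following self-contained sketch restates the new return convention with a simplified, hypothetical copy of the verl helper (get_verl_checkpoint_info_sketch is not the repository function; error handling is omitted).

import os
import tempfile
from typing import Optional, Tuple


def get_verl_checkpoint_info_sketch(
    checkpoint_path: str, step_num: Optional[int] = None
) -> Tuple[str, int]:
    # Simplified restatement: return the checkpoint dir together with its step number.
    if step_num is None:
        iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
        with open(iteration_file, encoding="utf-8") as f:
            iteration = f.read().strip()
        return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration)
    return os.path.join(checkpoint_path, f"global_step_{step_num}"), step_num


with tempfile.TemporaryDirectory() as root:
    with open(os.path.join(root, "latest_checkpointed_iteration.txt"), "w") as f:
        f.write("8")
    ckpt_dir, step = get_verl_checkpoint_info_sketch(root)  # latest checkpoint
    assert step == 8 and ckpt_dir.endswith("global_step_8")
    _, step_4 = get_verl_checkpoint_info_sketch(root, step_num=4)  # explicit step
    assert step_4 == 4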

trinity/common/models/vllm_async_model.py

Lines changed: 7 additions & 5 deletions
@@ -102,7 +102,7 @@ def __init__(
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.state_dict_meta = None
-        self.ckp_version = 0  # TODO: resume the value from the checkpoint
+        self.model_version = 0  # TODO: resume the value from the checkpoint
         self.api_server_host = None
         self.api_server_port = None

@@ -266,13 +266,15 @@ async def _collective_rpc(
             method, timeout, args, kwargs
         )

-    async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
+    async def sync_model(
+        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
+    ) -> bool:
         """Sync model weights to vLLM."""
         if update_weight_args_list is not None:
             await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
         await self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
-        self.ckp_version += 1
+        self.model_version = model_version
         return True

     async def init_process_group(
@@ -352,8 +354,8 @@ async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()

-    def get_ckp_version(self) -> int:
-        return self.ckp_version
+    def get_model_version(self) -> int:
+        return self.model_version

     async def sleep(self, level: int = 1) -> None:
         await self.async_llm.sleep(level=level)

trinity/common/models/vllm_model.py

Lines changed: 7 additions & 5 deletions
@@ -87,7 +87,7 @@ def __init__(self, config: InferenceModelConfig):
         self.action_mask_method = tokenize_and_mask_messages_hf
         self.lock = threading.Lock()
         self.state_dict_meta = None
-        self.ckp_version = 0  # TODO: resume the value from the checkpoint
+        self.model_version = 0  # TODO: resume the value from the checkpoint

     def init_process_group(
         self,
@@ -278,14 +278,16 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
     def has_api_server(self) -> bool:
         return False

-    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
+    def sync_model(
+        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
+    ) -> bool:
         """Sync model weights to vLLM."""
         if update_weight_args_list is not None:
             self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
         self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
-        self.ckp_version += 1
+        self.model_version = model_version
         return True

-    def get_ckp_version(self) -> int:
-        return self.ckp_version
+    def get_model_version(self) -> int:
+        return self.model_version

trinity/explorer/explorer.py

Lines changed: 7 additions & 5 deletions
@@ -125,7 +125,7 @@ def _init_runner_pool(self) -> RunnerPool:
         self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
         return RunnerPool(self.config, self.models, self.auxiliary_models)

-    async def _update_model_weight(self, state_dict: dict) -> None:
+    async def _update_model_weight(self, step_num: int, state_dict: dict) -> None:
         # TODO: update model weight
         self.state_dict = state_dict
         if self.state_dict_meta is None:
@@ -136,29 +136,31 @@ async def _update_model_weight(self, state_dict: dict) -> None:
         else:
             update_weight_args_list = None
         await asyncio.gather(
-            *[model.sync_model.remote(update_weight_args_list) for model in self.models]
+            *[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models]
         )
         self.state_dict.clear()

     async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
         # TODO: support more checkpoint types
         try:
-            checkpoint_dir = get_checkpoint_dir_with_step_num(
+            checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=self.config.checkpoint_job_dir,
                 trainer_type=self.config.trainer.trainer_type,
                 step_num=step_num,
             )
             if checkpoint_dir == self.old_checkpoint:
                 return
             model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor"))
-            await self._update_model_weight(model_weights)
+            await self._update_model_weight(checkpoint_step_num, model_weights)
             self.old_checkpoint = checkpoint_dir
         except Exception as e:
             self.logger.warning(f"Fail to load checkpoint: {e}")

     async def _nccl_weights_update(self):
         assert self.state_dict_meta is not None
-        await asyncio.gather(*[model.sync_model.remote() for model in self.models])
+        await asyncio.gather(
+            *[model.sync_model.remote(self.explore_step_num) for model in self.models]
+        )

     async def prepare(self) -> None:
         """Preparation before running."""

trinity/explorer/workflow_runner.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def run_task(self, task: Task) -> Status:

             if not hasattr(exp, "info") or exp.info is None:
                 exp.info = {}
-            exp.info["model_version"] = self.model_wrapper.get_ckp_version()
+            exp.info["model_version"] = self.model_wrapper.model_version

             if not hasattr(exp, "metrics") or exp.metrics is None:
                 exp.metrics = {}
