Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/explorer/runner_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def run(self) -> List[Experience]:

@ray.remote
class DummyModel(InferenceModel):
def sync_model(self, update_weight_args_list):
def sync_model(self, model_version, update_weight_args_list):
return True

def get_ckp_version(self):
def get_model_version(self):
return 0

def init_process_group(
Expand All @@ -65,10 +65,10 @@ def init_process_group(

@ray.remote
class DummyAuxiliaryModel(InferenceModel):
def sync_model(self, update_weight_args_list):
def sync_model(self, model_version, update_weight_args_list):
return True

def get_ckp_version(self):
def get_model_version(self):
return 0

def init_process_group(
Expand Down
25 changes: 12 additions & 13 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for trainer."""

import multiprocessing
import os
import shutil
Expand Down Expand Up @@ -83,14 +84,12 @@ def test_trainer(self):
self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint
from trinity.common.models.utils import get_checkpoint_dir_with_step_num

checkpoint_step_4 = get_checkpoint_dir_with_step_num(
checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
step_num=4,
)
checkpoint_step_8 = get_checkpoint_dir_with_step_num(
checkpoint_step_8, _ = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
step_num=8,
Expand Down Expand Up @@ -158,13 +157,12 @@ def test_trainer(self):
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
ray.shutdown(_exiting_interpreter=True)
# check checkpoint
from trinity.common.models.utils import get_checkpoint_dir_with_step_num

checkpoint_step_4 = get_checkpoint_dir_with_step_num(
checkpoint_step_4, step_num = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
step_num=4,
)
self.assertEqual(step_num, 4)
self.assertTrue(os.path.exists(checkpoint_step_4))

def tearDown(self):
Expand Down Expand Up @@ -374,19 +372,20 @@ def test_fully_async_mode(self):
explorer2_cache = CacheManager(explorer2_config)
cache = explorer2_cache.load_explorer()
self.assertEqual(cache["latest_iteration"], 4)
self.assertIsNotNone(
# check the latest checkpoint
self.assertEqual(
get_checkpoint_dir_with_step_num(
checkpoint_root_path=explorer1_config.checkpoint_job_dir,
trainer_type="verl",
step_num=8,
)
)[1],
8,
)
self.assertIsNotNone(
self.assertEqual(
get_checkpoint_dir_with_step_num(
checkpoint_root_path=explorer2_config.checkpoint_job_dir,
trainer_type="verl",
step_num=8,
)
)[1],
8,
)
ray.shutdown()

Expand Down
8 changes: 5 additions & 3 deletions trinity/common/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
raise NotImplementedError

@abstractmethod
def get_ckp_version(self) -> int:
def get_model_version(self) -> int:
"""Get the checkpoint version."""

def get_available_address(self) -> Tuple[str, int]:
Expand Down Expand Up @@ -99,8 +99,10 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
else:
return ray.get(self.model.convert_messages_to_experience.remote(messages))

def get_ckp_version(self) -> int:
return ray.get(self.model.get_ckp_version.remote())
@property
def model_version(self) -> int:
"""Get the version of the model."""
return ray.get(self.model.get_model_version.remote())

def get_openai_client(self) -> openai.OpenAI:
"""Get the openai client.
Expand Down
31 changes: 24 additions & 7 deletions trinity/common/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,21 @@ def get_checkpoint_dir_with_step_num(
checkpoint_root_path: str,
trainer_type: str = "verl",
step_num: Optional[int] = None,
) -> str:
) -> Tuple[str, int]:
"""Get the checkpoint directory from a root checkpoint directory.

Args:
checkpoint_root_path (str): The root checkpoint directory.
trainer_type (str): The trainer type. Only support "verl" for now.
step_num (Optional[int], optional): The step number. Defaults to None.
step_num (Optional[int], optional): The step number. If specified,
load the checkpoint with the specified step number. If None,
load the latest checkpoint. Defaults to None.

Returns:
Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
"""
if trainer_type == "verl":
return get_verl_checkpoint_dir(checkpoint_path=checkpoint_root_path, step_num=step_num)
return get_verl_checkpoint_info(checkpoint_path=checkpoint_root_path, step_num=step_num)
else:
raise NotImplementedError(f"Unsupported trainer type {trainer_type}")

Expand Down Expand Up @@ -144,8 +149,20 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
raise ValueError(f"Unsupported placement: {placement}")


def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None) -> str:
"""Get the checkpoint directory from a Verl root checkpoint directory."""
def get_verl_checkpoint_info(
checkpoint_path: str, step_num: Optional[int] = None
) -> Tuple[str, int]:
"""Get the checkpoint directory from a Verl root checkpoint directory.

Args:
checkpoint_path (str): The root checkpoint directory.
step_num (Optional[int], optional): The step number. If specified,
load the checkpoint with the specified step number. If None,
load the latest checkpoint. Defaults to None.

Returns:
Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
"""
if step_num is None:
# load latest checkpoint
iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
Expand All @@ -154,12 +171,12 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
iteration_file, "r", encoding="utf-8"
) as f: # TODO: this file may be modified simultaneously
iteration = f.read().strip()
return os.path.join(checkpoint_path, f"global_step_{iteration}")
return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration)
else:
raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
else:
# load specific iteration checkpoint
return os.path.join(checkpoint_path, f"global_step_{step_num}")
return os.path.join(checkpoint_path, f"global_step_{step_num}"), step_num


# copy from verl/scripts/model_merger.py
Expand Down
12 changes: 7 additions & 5 deletions trinity/common/models/vllm_async_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
else:
self.action_mask_method = tokenize_and_mask_messages_hf
self.state_dict_meta = None
self.ckp_version = 0 # TODO: resume the value from the checkpoint
self.model_version = 0 # TODO: resume the value from the checkpoint
self.api_server_host = None
self.api_server_port = None

Expand Down Expand Up @@ -266,13 +266,15 @@ async def _collective_rpc(
method, timeout, args, kwargs
)

async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
async def sync_model(
self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
) -> bool:
"""Sync model weights to vLLM."""
if update_weight_args_list is not None:
await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
await self._collective_rpc("update_weight")
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
self.model_version = model_version
return True

async def init_process_group(
Expand Down Expand Up @@ -352,8 +354,8 @@ async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
async def reset_prefix_cache(self) -> None:
await self.async_llm.reset_prefix_cache()

def get_ckp_version(self) -> int:
return self.ckp_version
def get_model_version(self) -> int:
return self.model_version

async def sleep(self, level: int = 1) -> None:
await self.async_llm.sleep(level=level)
Expand Down
12 changes: 7 additions & 5 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, config: InferenceModelConfig):
self.action_mask_method = tokenize_and_mask_messages_hf
self.lock = threading.Lock()
self.state_dict_meta = None
self.ckp_version = 0 # TODO: resume the value from the checkpoint
self.model_version = 0 # TODO: resume the value from the checkpoint

def init_process_group(
self,
Expand Down Expand Up @@ -278,14 +278,16 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
def has_api_server(self) -> bool:
return False

def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
def sync_model(
self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
) -> bool:
"""Sync model weights to vLLM."""
if update_weight_args_list is not None:
self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
self._collective_rpc("update_weight")
self.logger.info("Sync model weights to vLLM successfully.")
self.ckp_version += 1
self.model_version = model_version
return True

def get_ckp_version(self) -> int:
return self.ckp_version
def get_model_version(self) -> int:
return self.model_version
12 changes: 7 additions & 5 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _init_runner_pool(self) -> RunnerPool:
self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
return RunnerPool(self.config, self.models, self.auxiliary_models)

async def _update_model_weight(self, state_dict: dict) -> None:
async def _update_model_weight(self, step_num: int, state_dict: dict) -> None:
# TODO: update model weight
self.state_dict = state_dict
if self.state_dict_meta is None:
Expand All @@ -135,29 +135,31 @@ async def _update_model_weight(self, state_dict: dict) -> None:
else:
update_weight_args_list = None
await asyncio.gather(
*[model.sync_model.remote(update_weight_args_list) for model in self.models]
*[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models]
)
self.state_dict.clear()

async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
# TODO: support more checkpoint types
try:
checkpoint_dir = get_checkpoint_dir_with_step_num(
checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
step_num=step_num,
)
if checkpoint_dir == self.old_checkpoint:
return
model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor"))
await self._update_model_weight(model_weights)
await self._update_model_weight(checkpoint_step_num, model_weights)
self.old_checkpoint = checkpoint_dir
except Exception as e:
self.logger.warning(f"Fail to load checkpoint: {e}")

async def _nccl_weights_update(self):
assert self.state_dict_meta is not None
await asyncio.gather(*[model.sync_model.remote() for model in self.models])
await asyncio.gather(
*[model.sync_model.remote(self.explore_step_num) for model in self.models]
)

async def prepare(self) -> None:
"""Preparation before running."""
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run_task(self, task: Task) -> Status:

if not hasattr(exp, "info") or exp.info is None:
exp.info = {}
exp.info["model_version"] = self.model_wrapper.get_ckp_version()
exp.info["model_version"] = self.model_wrapper.model_version

if not hasattr(exp, "metrics") or exp.metrics is None:
exp.metrics = {}
Expand Down