8 changes: 4 additions & 4 deletions tests/explorer/runner_pool_test.py
@@ -43,10 +43,10 @@ def run(self) -> List[Experience]:

 @ray.remote
 class DummyModel(InferenceModel):
-    def sync_model(self, update_weight_args_list):
+    def sync_model(self, model_version, update_weight_args_list):
         return True

-    def get_ckp_version(self):
+    def get_model_version(self):
         return 0

     def init_process_group(
@@ -65,10 +65,10 @@ def init_process_group(

 @ray.remote
 class DummyAuxiliaryModel(InferenceModel):
-    def sync_model(self, update_weight_args_list):
+    def sync_model(self, model_version, update_weight_args_list):
         return True

-    def get_ckp_version(self):
+    def get_model_version(self):
         return 0

     def init_process_group(
25 changes: 12 additions & 13 deletions tests/trainer/trainer_test.py
@@ -1,4 +1,5 @@
 """Tests for trainer."""
+
 import multiprocessing
 import os
 import shutil
@@ -83,14 +84,12 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
-
-        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=4,
         )
-        checkpoint_step_8 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_8, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=8,
@@ -158,13 +157,12 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
-
-        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4, step_num = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
             step_num=4,
         )
+        self.assertEqual(step_num, 4)
         self.assertTrue(os.path.exists(checkpoint_step_4))

     def tearDown(self):
@@ -374,19 +372,20 @@ def test_fully_async_mode(self):
         explorer2_cache = CacheManager(explorer2_config)
         cache = explorer2_cache.load_explorer()
         self.assertEqual(cache["latest_iteration"], 4)
-        self.assertIsNotNone(
+        # check the latest checkpoint
+        self.assertEqual(
             get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=explorer1_config.checkpoint_job_dir,
                 trainer_type="verl",
                 step_num=8,
-            )
+            )[1],
+            8,
         )
-        self.assertIsNotNone(
+        self.assertEqual(
             get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=explorer2_config.checkpoint_job_dir,
                 trainer_type="verl",
                 step_num=8,
-            )
+            )[1],
+            8,
         )
         ray.shutdown()
8 changes: 5 additions & 3 deletions trinity/common/models/model.py
@@ -49,7 +49,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
         raise NotImplementedError

     @abstractmethod
-    def get_ckp_version(self) -> int:
+    def get_model_version(self) -> int:
         """Get the checkpoint version."""

     def get_available_address(self) -> Tuple[str, int]:
@@ -99,8 +99,10 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         else:
             return ray.get(self.model.convert_messages_to_experience.remote(messages))

-    def get_ckp_version(self) -> int:
-        return ray.get(self.model.get_ckp_version.remote())
+    @property
+    def model_version(self) -> int:
+        """Get the version of the model."""
+        return ray.get(self.model.get_model_version.remote())

     def get_openai_client(self) -> openai.OpenAI:
         """Get the openai client.
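Callers now read the version as a property instead of invoking a getter. A minimal, runnable sketch of the new shape (the `Fake*` classes are hypothetical stand-ins; the real `ModelWrapper` forwards through `ray.get` to the remote actor):

```python
class FakeInferenceModel:
    """Hypothetical stand-in for an InferenceModel actor."""

    def get_model_version(self) -> int:
        return 0


class FakeModelWrapper:
    """Hypothetical stand-in for ModelWrapper; no Ray involved."""

    def __init__(self, model: FakeInferenceModel) -> None:
        self.model = model

    @property
    def model_version(self) -> int:
        # The real property wraps ray.get(self.model.get_model_version.remote()).
        return self.model.get_model_version()


wrapper = FakeModelWrapper(FakeInferenceModel())
assert wrapper.model_version == 0  # attribute access, no parentheses
```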
31 changes: 24 additions & 7 deletions trinity/common/models/utils.py
@@ -105,16 +105,21 @@ def get_checkpoint_dir_with_step_num(
     checkpoint_root_path: str,
     trainer_type: str = "verl",
     step_num: Optional[int] = None,
-) -> str:
+) -> Tuple[str, int]:
     """Get the checkpoint directory from a root checkpoint directory.

     Args:
         checkpoint_root_path (str): The root checkpoint directory.
         trainer_type (str): The trainer type. Only support "verl" for now.
-        step_num (Optional[int], optional): The step number. Defaults to None.
+        step_num (Optional[int], optional): The step number. If specified,
+            load the checkpoint with the specified step number. If None,
+            load the latest checkpoint. Defaults to None.
+
+    Returns:
+        Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
     """
     if trainer_type == "verl":
-        return get_verl_checkpoint_dir(checkpoint_path=checkpoint_root_path, step_num=step_num)
+        return get_verl_checkpoint_info(checkpoint_path=checkpoint_root_path, step_num=step_num)
     else:
         raise NotImplementedError(f"Unsupported trainer type {trainer_type}")

@@ -144,8 +149,20 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
         raise ValueError(f"Unsupported placement: {placement}")


-def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None) -> str:
-    """Get the checkpoint directory from a Verl root checkpoint directory."""
+def get_verl_checkpoint_info(
+    checkpoint_path: str, step_num: Optional[int] = None
+) -> Tuple[str, int]:
+    """Get the checkpoint directory from a Verl root checkpoint directory.
+
+    Args:
+        checkpoint_path (str): The root checkpoint directory.
+        step_num (Optional[int], optional): The step number. If specified,
+            load the checkpoint with the specified step number. If None,
+            load the latest checkpoint. Defaults to None.
+
+    Returns:
+        Tuple[str, int]: The checkpoint directory and the step number of the checkpoint.
+    """
     if step_num is None:
         # load latest checkpoint
         iteration_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt")
@@ -154,12 +171,12 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
             iteration_file, "r", encoding="utf-8"
         ) as f:  # TODO: this file may be modified simultaneously
             iteration = f.read().strip()
-            return os.path.join(checkpoint_path, f"global_step_{iteration}")
+            return os.path.join(checkpoint_path, f"global_step_{iteration}"), int(iteration)
         else:
             raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
     else:
         # load specific iteration checkpoint
-        return os.path.join(checkpoint_path, f"global_step_{step_num}")
+        return os.path.join(checkpoint_path, f"global_step_{step_num}"), step_num


 # copy from verl/scripts/model_merger.py
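A quick sketch of the new tuple contract, assuming the `trinity` package is importable and using a throwaway directory laid out the way the verl helper expects:

```python
import os
import tempfile

from trinity.common.models.utils import get_checkpoint_dir_with_step_num

root = tempfile.mkdtemp()
for step in (4, 8):
    os.makedirs(os.path.join(root, f"global_step_{step}"))
with open(os.path.join(root, "latest_checkpointed_iteration.txt"), "w") as f:
    f.write("8")

# Explicit step: the step number is echoed back alongside the path.
path, step = get_checkpoint_dir_with_step_num(
    checkpoint_root_path=root, trainer_type="verl", step_num=4
)
assert (path, step) == (os.path.join(root, "global_step_4"), 4)

# No step: latest_checkpointed_iteration.txt decides, parsed to an int.
path, step = get_checkpoint_dir_with_step_num(checkpoint_root_path=root, trainer_type="verl")
assert (path, step) == (os.path.join(root, "global_step_8"), 8)
```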
12 changes: 7 additions & 5 deletions trinity/common/models/vllm_async_model.py
@@ -102,7 +102,7 @@ def __init__(
         else:
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.state_dict_meta = None
-        self.ckp_version = 0  # TODO: resume the value from the checkpoint
+        self.model_version = 0  # TODO: resume the value from the checkpoint
         self.api_server_host = None
         self.api_server_port = None

@@ -266,13 +266,15 @@ async def _collective_rpc(
             method, timeout, args, kwargs
         )

-    async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
+    async def sync_model(
+        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
+    ) -> bool:
         """Sync model weights to vLLM."""
         if update_weight_args_list is not None:
             await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
         await self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
-        self.ckp_version += 1
+        self.model_version = model_version
         return True

     async def init_process_group(
@@ -352,8 +354,8 @@ async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]:
     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()

-    def get_ckp_version(self) -> int:
-        return self.ckp_version
+    def get_model_version(self) -> int:
+        return self.model_version

     async def sleep(self, level: int = 1) -> None:
         await self.async_llm.sleep(level=level)
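One behavioural change worth noting: the engine no longer counts syncs (`self.ckp_version += 1`) but records whatever step the caller supplies, so the reported version can jump when intermediate syncs are skipped. A toy stand-in for that contract (hypothetical class, no vLLM or Ray required):

```python
class FakeEngine:
    """Hypothetical stand-in for the vLLM actors' versioning contract."""

    def __init__(self) -> None:
        self.model_version = 0  # the real code still has a TODO to resume this

    def sync_model(self, model_version: int, update_weight_args_list=None) -> bool:
        # The real method pushes weights over collective RPC before this line.
        self.model_version = model_version
        return True


engine = FakeEngine()
engine.sync_model(4)
engine.sync_model(8)  # suppose the intermediate syncs never happened
assert engine.model_version == 8  # the old `+= 1` logic would report 2 here
```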
12 changes: 7 additions & 5 deletions trinity/common/models/vllm_model.py
@@ -87,7 +87,7 @@ def __init__(self, config: InferenceModelConfig):
             self.action_mask_method = tokenize_and_mask_messages_hf
         self.lock = threading.Lock()
         self.state_dict_meta = None
-        self.ckp_version = 0  # TODO: resume the value from the checkpoint
+        self.model_version = 0  # TODO: resume the value from the checkpoint

     def init_process_group(
         self,
@@ -278,14 +278,16 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
     def has_api_server(self) -> bool:
         return False

-    def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
+    def sync_model(
+        self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None
+    ) -> bool:
         """Sync model weights to vLLM."""
         if update_weight_args_list is not None:
             self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
         self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
-        self.ckp_version += 1
+        self.model_version = model_version
         return True

-    def get_ckp_version(self) -> int:
-        return self.ckp_version
+    def get_model_version(self) -> int:
+        return self.model_version
12 changes: 7 additions & 5 deletions trinity/explorer/explorer.py
@@ -124,7 +124,7 @@ def _init_runner_pool(self) -> RunnerPool:
         self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
         return RunnerPool(self.config, self.models, self.auxiliary_models)

-    async def _update_model_weight(self, state_dict: dict) -> None:
+    async def _update_model_weight(self, step_num: int, state_dict: dict) -> None:
         # TODO: update model weight
         self.state_dict = state_dict
         if self.state_dict_meta is None:
@@ -135,29 +135,31 @@ async def _update_model_weight(self, state_dict: dict) -> None:
         else:
             update_weight_args_list = None
         await asyncio.gather(
-            *[model.sync_model.remote(update_weight_args_list) for model in self.models]
+            *[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models]
         )
         self.state_dict.clear()

     async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None:
         # TODO: support more checkpoint types
         try:
-            checkpoint_dir = get_checkpoint_dir_with_step_num(
+            checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num(
                 checkpoint_root_path=self.config.checkpoint_job_dir,
                 trainer_type=self.config.trainer.trainer_type,
                 step_num=step_num,
             )
             if checkpoint_dir == self.old_checkpoint:
                 return
             model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor"))
-            await self._update_model_weight(model_weights)
+            await self._update_model_weight(checkpoint_step_num, model_weights)
             self.old_checkpoint = checkpoint_dir
         except Exception as e:
             self.logger.warning(f"Fail to load checkpoint: {e}")

     async def _nccl_weights_update(self):
         assert self.state_dict_meta is not None
-        await asyncio.gather(*[model.sync_model.remote() for model in self.models])
+        await asyncio.gather(
+            *[model.sync_model.remote(self.explore_step_num) for model in self.models]
+        )

     async def prepare(self) -> None:
         """Preparation before running."""
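Both update paths now pass an explicit version downstream: the checkpoint path forwards the step parsed from the checkpoint directory, while the NCCL path forwards the explorer's own `explore_step_num`. A runnable sketch of the fan-out, with plain `asyncio` standing in for the Ray `remote` calls and a hypothetical `FakeAsyncEngine`:

```python
import asyncio


class FakeAsyncEngine:
    """Hypothetical stand-in for one inference engine handle."""

    def __init__(self) -> None:
        self.model_version = 0

    async def sync_model(self, model_version: int, update_weight_args_list=None) -> bool:
        self.model_version = model_version
        return True


async def sync_all(models, step_num, update_weight_args_list=None):
    # Fan the same version out to every engine so experiences generated
    # afterwards carry a consistent model_version tag.
    results = await asyncio.gather(
        *[m.sync_model(step_num, update_weight_args_list) for m in models]
    )
    return all(results)


models = [FakeAsyncEngine(), FakeAsyncEngine()]
assert asyncio.run(sync_all(models, step_num=8))
assert all(m.model_version == 8 for m in models)
```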
2 changes: 1 addition & 1 deletion trinity/explorer/workflow_runner.py
@@ -85,7 +85,7 @@ def run_task(self, task: Task) -> Status:

         if not hasattr(exp, "info") or exp.info is None:
             exp.info = {}
-        exp.info["model_version"] = self.model_wrapper.get_ckp_version()
+        exp.info["model_version"] = self.model_wrapper.model_version

         if not hasattr(exp, "metrics") or exp.metrics is None:
             exp.metrics = {}
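With this tag in place, every experience records which weights generated it. A hypothetical downstream use, not part of this PR, is filtering stale experiences before training:

```python
def filter_stale(experiences, current_version: int, max_lag: int = 1):
    """Keep experiences whose generating weights are at most max_lag versions old."""
    return [
        exp
        for exp in experiences
        if current_version - exp.info.get("model_version", 0) <= max_lag
    ]
```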