Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions tests/explorer/runner_pool_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os
import time
import unittest
Expand All @@ -8,7 +9,7 @@

from tests.tools import get_unittest_dataset_config
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.common.config import StorageConfig, load_config
from trinity.common.config import InferenceModelConfig, StorageConfig, load_config
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel
Expand All @@ -22,7 +23,7 @@
@WORKFLOWS.register_module("dummy_workflow")
class DummyWorkflow(Workflow):
def __init__(self, model, task, auxiliary_models):
super().__init__(model, task)
super().__init__(model, task, auxiliary_models)
self.error_type = task.task_desc
self.seconds = None
if "timeout" in self.error_type:
Expand All @@ -35,6 +36,8 @@ def run(self) -> List[Experience]:
raise ValueError("Exception occurred")
elif self.error_type == "exit":
exit(1)
elif self.error_type == "auxiliary_models":
assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2
return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)]


Expand All @@ -60,6 +63,34 @@ def init_process_group(
pass


@ray.remote
class DummyAuxiliaryModel(InferenceModel):
    """Minimal stand-in for an auxiliary inference model actor.

    Implements just enough of the ``InferenceModel`` interface for the
    runner-pool tests: it always reports a ready OpenAI-compatible API
    server, and treats weight sync / process-group setup as no-ops.
    """

    def has_api_server(self) -> bool:
        # Advertise an API server so callers exercise the URL wiring.
        return True

    def api_server_ready(self) -> str:
        # NOTE(review): "localhosts" looks like a typo for "localhost";
        # harmless here since the tests never actually connect to it — confirm.
        return "http://localhosts:12345"

    def sync_model(self, update_weight_args_list):
        # Pretend weight synchronization always succeeds.
        return True

    def get_ckp_version(self):
        # A fixed checkpoint version is sufficient for testing.
        return 0

    def init_process_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
        timeout: int = 1200,
        update_with_checkpoint: bool = True,
    ) -> None:
        # Distributed setup is irrelevant for the dummy actor; do nothing.
        pass


class RunnerPoolTest(unittest.TestCase):
def setUp(self):
ray.init(ignore_reinit_error=True)
Expand Down Expand Up @@ -184,3 +215,43 @@ def test_runner_pool(self):
exps = self.queue.read()
self.assertEqual(len(exps), 2) # `timeout_2` and `success`
self.assertEqual(len(pool._idle_actors), self.config.explorer.runner_num)

def test_runner_pool_with_auxiliary_models(self):
    """A workflow that checks its auxiliary models should complete OK."""
    cfg = copy.deepcopy(self.config)
    # Two auxiliary model groups, one engine each, mirrored below by the
    # two lists of dummy auxiliary actors handed to the pool.
    cfg.explorer.auxiliary_models = [InferenceModelConfig(engine_num=1) for _ in range(2)]
    runner_pool = RunnerPool(
        cfg,
        [DummyModel.remote(), DummyModel.remote()],
        [[DummyAuxiliaryModel.remote()], [DummyAuxiliaryModel.remote()]],
    )

    dataset_cfg = get_unittest_dataset_config("countdown")
    # The prompt text routes DummyWorkflow into its auxiliary-model
    # assertion branch.
    task = Task(
        workflow=DummyWorkflow,
        format_args=dataset_cfg.format,
        rollout_args=dataset_cfg.rollout_args,
        is_eval=False,
        raw_task={dataset_cfg.format.prompt_key: "auxiliary_models"},
    )
    runner_pool.run_tasks(tasks=[task])

    # The single `auxiliary_models` task should come back quickly and succeed.
    start = time.time()
    statuses = runner_pool.get_next_unorder()
    elapsed = time.time() - start
    self.assertLess(elapsed, 1)
    self.assertEqual(len(statuses), 1)
    self.assertTrue(statuses[0].ok)
11 changes: 7 additions & 4 deletions trinity/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def allocate(self, num: int) -> list:

def create_inference_models(
config: Config,
) -> Tuple[List[InferenceModel], List[InferenceModel]]:
) -> Tuple[List[InferenceModel], List[List[InferenceModel]]]:
"""Create `engine_num` rollout models.

Each model has `tensor_parallel_size` workers.
Expand Down Expand Up @@ -116,11 +116,12 @@ def create_inference_models(

# create auxiliary models
for model_config in config.explorer.auxiliary_models:
engines = []
for _ in range(model_config.engine_num):
bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size)
model_config.enable_openai_api = True
model_config.engine_type = "vllm_async"
auxiliary_engines.append(
engines.append(
ray.remote(vLLMAysncRolloutModel)
.options(
num_cpus=0,
Expand All @@ -132,8 +133,10 @@ def create_inference_models(
)
.remote(config=model_config)
)
auxiliary_engines.append(engines)
# all auxiliary engines run api server
for engine in auxiliary_engines:
engine.run_api_server.remote()
for engines in auxiliary_engines:
for engine in engines:
engine.run_api_server.remote()

return rollout_engines, auxiliary_engines
2 changes: 1 addition & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _init_runner_pool(self) -> RunnerPool:
f"Number of Runners is less than number of models, set to {self.config.explorer.runner_num}"
)
self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
return RunnerPool(self.config, self.models)
return RunnerPool(self.config, self.models, self.auxiliary_models)

def _update_model_weight(self, state_dict: dict) -> None:
# TODO: update model weight
Expand Down
24 changes: 21 additions & 3 deletions trinity/explorer/runner_pool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Runner pool for running tasks in parallel. Modified from ray.util.actor_pool.ActorPool."""
import random
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import ray

from trinity.common.config import Config
from trinity.common.models.model import InferenceModel
from trinity.common.workflows import Task
from trinity.explorer.workflow_runner import Status, WorkflowRunner
from trinity.utils.log import get_logger
Expand All @@ -19,11 +20,17 @@ class RunnerPool:
`config.explorer.max_timeout`.
"""

def __init__(self, config: Config, models: List):
def __init__(
self,
config: Config,
models: List[InferenceModel],
auxiliary_models: Optional[List[List[InferenceModel]]] = None,
):
# actors to be used
self.logger = get_logger(__name__)
self.config = config
self.models = models
self.auxiliary_models = auxiliary_models or []
self.timeout = config.explorer.max_timeout
self.max_retry_times = config.explorer.max_retry_times

Expand All @@ -44,6 +51,9 @@ def __init__(self, config: Config, models: List):

# create new actors
self.engine_status = [0] * config.explorer.rollout_model.engine_num
self.auxiliary_engine_status_list = [
[0] * cfg.engine_num for cfg in config.explorer.auxiliary_models
]
self._idle_actors = list()
self.actor_to_engine_index = {}
self._create_actors(config.explorer.runner_num)
Expand All @@ -52,7 +62,15 @@ def _create_actors(self, num: int = 1):
new_actors = []
for _ in range(num):
engine_index = self.engine_status.index(min(self.engine_status))
new_actor = WorkflowRunner.remote(self.config, self.models[engine_index])
selected_auxiliary_models = [
models[engine_status.index(min(engine_status))]
for models, engine_status in zip(
self.auxiliary_models, self.auxiliary_engine_status_list
)
]
new_actor = WorkflowRunner.remote(
self.config, self.models[engine_index], selected_auxiliary_models
)
new_actors.append(new_actor)
self.engine_status[engine_index] += 1
self.actor_to_engine_index[new_actor] = engine_index
Expand Down