Skip to content

Commit 318da40

Browse files
authored
Bug fix in auxiliary_models. (#55)
1 parent 7bddb32 commit 318da40

File tree

4 files changed

+102
-10
lines changed

4 files changed

+102
-10
lines changed

tests/explorer/runner_pool_test.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import os
23
import time
34
import unittest
@@ -8,7 +9,7 @@
89

910
from tests.tools import get_unittest_dataset_config
1011
from trinity.buffer.reader.queue_reader import QueueReader
11-
from trinity.common.config import StorageConfig, load_config
12+
from trinity.common.config import InferenceModelConfig, StorageConfig, load_config
1213
from trinity.common.constants import AlgorithmType, StorageType
1314
from trinity.common.experience import Experience
1415
from trinity.common.models.model import InferenceModel
@@ -22,7 +23,7 @@
2223
@WORKFLOWS.register_module("dummy_workflow")
2324
class DummyWorkflow(Workflow):
2425
def __init__(self, model, task, auxiliary_models):
25-
super().__init__(model, task)
26+
super().__init__(model, task, auxiliary_models)
2627
self.error_type = task.task_desc
2728
self.seconds = None
2829
if "timeout" in self.error_type:
@@ -35,6 +36,8 @@ def run(self) -> List[Experience]:
3536
raise ValueError("Exception occurred")
3637
elif self.error_type == "exit":
3738
exit(1)
39+
elif self.error_type == "auxiliary_models":
40+
assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2
3841
return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)]
3942

4043

@@ -60,6 +63,34 @@ def init_process_group(
6063
pass
6164

6265

66+
@ray.remote
class DummyAuxiliaryModel(InferenceModel):
    """Stub auxiliary inference model used by the runner-pool tests.

    Every method returns a fixed value so the pool can be exercised without
    loading a real model. The stub advertises an API server but does not
    start one itself; it only hands back a placeholder URL.
    """

    def sync_model(self, update_weight_args_list):
        """Report a successful weight sync; the arguments are ignored."""
        return True

    def get_ckp_version(self):
        """Return a constant checkpoint version (always 0)."""
        return 0

    def init_process_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
        timeout: int = 1200,
        update_with_checkpoint: bool = True,
    ) -> None:
        """Accept the process-group arguments and do nothing with them."""
        pass

    def has_api_server(self) -> bool:
        """Always claim that an API server is available."""
        return True

    def api_server_ready(self) -> str:
        """Return the placeholder address of the (non-existent) API server."""
        # "localhost" (not "localhosts") so the placeholder is at least a
        # well-formed, resolvable host name.
        return "http://localhost:12345"
93+
6394
class RunnerPoolTest(unittest.TestCase):
6495
def setUp(self):
6596
ray.init(ignore_reinit_error=True)
@@ -184,3 +215,43 @@ def test_runner_pool(self):
184215
exps = self.queue.read()
185216
self.assertEqual(len(exps), 2) # `timeout_2` and `success`
186217
self.assertEqual(len(pool._idle_actors), self.config.explorer.runner_num)
218+
219+
def test_runner_pool_with_auxiliary_models(self):
    """Check that a RunnerPool hands auxiliary models to its workflows.

    Two auxiliary model groups (one engine each) are configured and a single
    task whose prompt text is "auxiliary_models" is submitted. That text is
    presumably what DummyWorkflow sees as its task description, which makes
    it assert that exactly two auxiliary models were received -- confirm
    against Task.task_desc. The task must finish quickly and report success.
    """
    config = copy.deepcopy(self.config)
    # One config entry per auxiliary model group; each group has one engine.
    config.explorer.auxiliary_models = [
        InferenceModelConfig(
            engine_num=1,
        ),
        InferenceModelConfig(
            engine_num=1,
        ),
    ]
    pool = RunnerPool(
        config,
        [DummyModel.remote(), DummyModel.remote()],
        # Outer list: model groups; inner list: engines of that group.
        [[DummyAuxiliaryModel.remote()], [DummyAuxiliaryModel.remote()]],
    )
    taskset_config = get_unittest_dataset_config("countdown")
    tasks = [
        Task(
            workflow=DummyWorkflow,
            format_args=taskset_config.format,
            rollout_args=taskset_config.rollout_args,
            is_eval=False,
            raw_task={
                taskset_config.format.prompt_key: "auxiliary_models",
            },
        ),
    ]

    pool.run_tasks(
        tasks=tasks,
    )

    # `auxiliary_models`
    # Use a monotonic clock so a wall-clock adjustment cannot skew the
    # measured interval.
    st = time.monotonic()
    status = pool.get_next_unorder()
    et = time.monotonic()
    # assertLess reports the measured duration when the check fails.
    self.assertLess(et - st, 1)
    self.assertEqual(len(status), 1)
    self.assertTrue(status[0].ok)

trinity/common/models/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def allocate(self, num: int) -> list:
3434

3535
def create_inference_models(
3636
config: Config,
37-
) -> Tuple[List[InferenceModel], List[InferenceModel]]:
37+
) -> Tuple[List[InferenceModel], List[List[InferenceModel]]]:
3838
"""Create `engine_num` rollout models.
3939
4040
Each model has `tensor_parallel_size` workers.
@@ -116,11 +116,12 @@ def create_inference_models(
116116

117117
# create auxiliary models
118118
for model_config in config.explorer.auxiliary_models:
119+
engines = []
119120
for _ in range(model_config.engine_num):
120121
bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size)
121122
model_config.enable_openai_api = True
122123
model_config.engine_type = "vllm_async"
123-
auxiliary_engines.append(
124+
engines.append(
124125
ray.remote(vLLMAysncRolloutModel)
125126
.options(
126127
num_cpus=0,
@@ -132,8 +133,10 @@ def create_inference_models(
132133
)
133134
.remote(config=model_config)
134135
)
136+
auxiliary_engines.append(engines)
135137
# all auxiliary engines run api server
136-
for engine in auxiliary_engines:
137-
engine.run_api_server.remote()
138+
for engines in auxiliary_engines:
139+
for engine in engines:
140+
engine.run_api_server.remote()
138141

139142
return rollout_engines, auxiliary_engines

trinity/explorer/explorer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _init_runner_pool(self) -> RunnerPool:
114114
f"Number of Runners is less than number of models, set to {self.config.explorer.runner_num}"
115115
)
116116
self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners")
117-
return RunnerPool(self.config, self.models)
117+
return RunnerPool(self.config, self.models, self.auxiliary_models)
118118

119119
def _update_model_weight(self, state_dict: dict) -> None:
120120
# TODO: update model weight

trinity/explorer/runner_pool.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Runner pool for running tasks in parallel. Modified from ray.util.actor_pool.ActorPool."""
22
import random
3-
from typing import List, Tuple, Union
3+
from typing import List, Optional, Tuple, Union
44

55
import ray
66

77
from trinity.common.config import Config
8+
from trinity.common.models.model import InferenceModel
89
from trinity.common.workflows import Task
910
from trinity.explorer.workflow_runner import Status, WorkflowRunner
1011
from trinity.utils.log import get_logger
@@ -19,11 +20,17 @@ class RunnerPool:
1920
`config.explorer.max_timeout`.
2021
"""
2122

22-
def __init__(self, config: Config, models: List):
23+
def __init__(
24+
self,
25+
config: Config,
26+
models: List[InferenceModel],
27+
auxiliary_models: Optional[List[List[InferenceModel]]] = None,
28+
):
2329
# actors to be used
2430
self.logger = get_logger(__name__)
2531
self.config = config
2632
self.models = models
33+
self.auxiliary_models = auxiliary_models or []
2734
self.timeout = config.explorer.max_timeout
2835
self.max_retry_times = config.explorer.max_retry_times
2936

@@ -44,6 +51,9 @@ def __init__(self, config: Config, models: List):
4451

4552
# create new actors
4653
self.engine_status = [0] * config.explorer.rollout_model.engine_num
54+
self.auxiliary_engine_status_list = [
55+
[0] * cfg.engine_num for cfg in config.explorer.auxiliary_models
56+
]
4757
self._idle_actors = list()
4858
self.actor_to_engine_index = {}
4959
self._create_actors(config.explorer.runner_num)
@@ -52,7 +62,15 @@ def _create_actors(self, num: int = 1):
5262
new_actors = []
5363
for _ in range(num):
5464
engine_index = self.engine_status.index(min(self.engine_status))
55-
new_actor = WorkflowRunner.remote(self.config, self.models[engine_index])
65+
selected_auxiliary_models = [
66+
models[engine_status.index(min(engine_status))]
67+
for models, engine_status in zip(
68+
self.auxiliary_models, self.auxiliary_engine_status_list
69+
)
70+
]
71+
new_actor = WorkflowRunner.remote(
72+
self.config, self.models[engine_index], selected_auxiliary_models
73+
)
5674
new_actors.append(new_actor)
5775
self.engine_status[engine_index] += 1
5876
self.actor_to_engine_index[new_actor] = engine_index

0 commit comments

Comments
 (0)