Skip to content

Commit 5e82f55

Browse files
committed
bug fix
1 parent 846b2bf commit 5e82f55

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

tests/explorer/runner_pool_test.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def run(self) -> List[Experience]:
3737
elif self.error_type == "exit":
3838
exit(1)
3939
elif self.error_type == "auxiliary_models":
40-
assert self.auxiliary_models is not None and len(self.auxiliary_models) > 0
40+
assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2
4141
return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)]
4242

4343

@@ -221,12 +221,15 @@ def test_runner_pool_with_auxiliary_models(self):
221221
config.explorer.auxiliary_models = [
222222
InferenceModelConfig(
223223
engine_num=1,
224-
)
224+
),
225+
InferenceModelConfig(
226+
engine_num=1,
227+
),
225228
]
226229
pool = RunnerPool(
227230
config,
228231
[DummyModel.remote(), DummyModel.remote()],
229-
[DummyAuxiliaryModel.remote(), DummyAuxiliaryModel.remote()],
232+
[[DummyAuxiliaryModel.remote()], [DummyAuxiliaryModel.remote()]],
230233
)
231234
taskset_config = get_unittest_dataset_config("countdown")
232235
tasks = [

trinity/common/models/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def allocate(self, num: int) -> list:
3434

3535
def create_inference_models(
3636
config: Config,
37-
) -> Tuple[List[InferenceModel], List[InferenceModel]]:
37+
) -> Tuple[List[InferenceModel], List[List[InferenceModel]]]:
3838
"""Create `engine_num` rollout models.
3939
4040
Each model has `tensor_parallel_size` workers.
@@ -116,11 +116,12 @@ def create_inference_models(
116116

117117
# create auxiliary models
118118
for model_config in config.explorer.auxiliary_models:
119+
engines = []
119120
for _ in range(model_config.engine_num):
120121
bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size)
121122
model_config.enable_openai_api = True
122123
model_config.engine_type = "vllm_async"
123-
auxiliary_engines.append(
124+
engines.append(
124125
ray.remote(vLLMAysncRolloutModel)
125126
.options(
126127
num_cpus=0,
@@ -132,8 +133,10 @@ def create_inference_models(
132133
)
133134
.remote(config=model_config)
134135
)
136+
auxiliary_engines.append(engines)
135137
# all auxiliary engines run api server
136-
for engine in auxiliary_engines:
137-
engine.run_api_server.remote()
138+
for engines in auxiliary_engines:
139+
for engine in engines:
140+
engine.run_api_server.remote()
138141

139142
return rollout_engines, auxiliary_engines

trinity/explorer/runner_pool.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import ray
66

77
from trinity.common.config import Config
8+
from trinity.common.models.model import InferenceModel
89
from trinity.common.workflows import Task
910
from trinity.explorer.workflow_runner import Status, WorkflowRunner
1011
from trinity.utils.log import get_logger
@@ -19,13 +20,17 @@ class RunnerPool:
1920
`config.explorer.max_timeout`.
2021
"""
2122

22-
def __init__(self, config: Config, models: List, auxiliary_models: Optional[List] = None):
23+
def __init__(
24+
self,
25+
config: Config,
26+
models: List[InferenceModel],
27+
auxiliary_models: Optional[List[List[InferenceModel]]] = None,
28+
):
2329
# actors to be used
2430
self.logger = get_logger(__name__)
2531
self.config = config
2632
self.models = models
2733
self.auxiliary_models = auxiliary_models or []
28-
self.auxiliary_models = [self.auxiliary_models] # TODO: support multiple auxiliary models
2934
self.timeout = config.explorer.max_timeout
3035
self.max_retry_times = config.explorer.max_retry_times
3136

@@ -47,10 +52,7 @@ def __init__(self, config: Config, models: List, auxiliary_models: Optional[List
4752
# create new actors
4853
self.engine_status = [0] * config.explorer.rollout_model.engine_num
4954
self.auxiliary_engine_status_list = [
50-
[0]
51-
* len(self.auxiliary_models[0])
52-
# TODO: support multiple auxiliary models
53-
# [0] * cfg.engine_num for cfg in config.explorer.auxiliary_models
55+
[0] * cfg.engine_num for cfg in config.explorer.auxiliary_models
5456
]
5557
self._idle_actors = list()
5658
self.actor_to_engine_index = {}

0 commit comments

Comments
 (0)