diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 1dc14ded69..036339e747 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -1,3 +1,4 @@ +import copy import os import time import unittest @@ -8,7 +9,7 @@ from tests.tools import get_unittest_dataset_config from trinity.buffer.reader.queue_reader import QueueReader -from trinity.common.config import StorageConfig, load_config +from trinity.common.config import InferenceModelConfig, StorageConfig, load_config from trinity.common.constants import AlgorithmType, StorageType from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel @@ -22,7 +23,7 @@ @WORKFLOWS.register_module("dummy_workflow") class DummyWorkflow(Workflow): def __init__(self, model, task, auxiliary_models): - super().__init__(model, task) + super().__init__(model, task, auxiliary_models) self.error_type = task.task_desc self.seconds = None if "timeout" in self.error_type: @@ -35,6 +36,8 @@ def run(self) -> List[Experience]: raise ValueError("Exception occurred") elif self.error_type == "exit": exit(1) + elif self.error_type == "auxiliary_models": + assert self.auxiliary_models is not None and len(self.auxiliary_models) == 2 return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)] @@ -60,6 +63,34 @@ def init_process_group( pass +@ray.remote +class DummyAuxiliaryModel(InferenceModel): + def sync_model(self, update_weight_args_list): + return True + + def get_ckp_version(self): + return 0 + + def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + timeout: int = 1200, + update_with_checkpoint: bool = True, + ) -> None: + pass + + def has_api_server(self) -> bool: + return True + + def api_server_ready(self) -> str: + return "http://localhosts:12345" + + class RunnerPoolTest(unittest.TestCase): def setUp(self): ray.init(ignore_reinit_error=True) @@ -184,3 +215,43 @@ def test_runner_pool(self): exps = self.queue.read() self.assertEqual(len(exps), 2) # `timeout_2` and `success` self.assertEqual(len(pool._idle_actors), self.config.explorer.runner_num) + + def test_runner_pool_with_auxiliary_models(self): + config = copy.deepcopy(self.config) + config.explorer.auxiliary_models = [ + InferenceModelConfig( + engine_num=1, + ), + InferenceModelConfig( + engine_num=1, + ), + ] + pool = RunnerPool( + config, + [DummyModel.remote(), DummyModel.remote()], + [[DummyAuxiliaryModel.remote()], [DummyAuxiliaryModel.remote()]], + ) + taskset_config = get_unittest_dataset_config("countdown") + tasks = [ + Task( + workflow=DummyWorkflow, + format_args=taskset_config.format, + rollout_args=taskset_config.rollout_args, + is_eval=False, + raw_task={ + taskset_config.format.prompt_key: "auxiliary_models", + }, + ), + ] + + pool.run_tasks( + tasks=tasks, + ) + + # `auxiliary_models` + st = time.time() + status = pool.get_next_unorder() + et = time.time() + self.assertTrue(et - st < 1) + self.assertEqual(len(status), 1) + self.assertTrue(status[0].ok) diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 7324ff2a47..25cb927799 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -34,7 +34,7 @@ def allocate(self, num: int) -> list: def create_inference_models( config: Config, -) -> Tuple[List[InferenceModel], List[InferenceModel]]: +) -> Tuple[List[InferenceModel], List[List[InferenceModel]]]: """Create `engine_num` rollout models. Each model has `tensor_parallel_size` workers. @@ -116,11 +116,12 @@ def create_inference_models( # create auxiliary models for model_config in config.explorer.auxiliary_models: + engines = [] for _ in range(model_config.engine_num): bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) model_config.enable_openai_api = True model_config.engine_type = "vllm_async" - auxiliary_engines.append( + engines.append( ray.remote(vLLMAysncRolloutModel) .options( num_cpus=0, @@ -132,8 +133,10 @@ def create_inference_models( ) .remote(config=model_config) ) + auxiliary_engines.append(engines) # all auxiliary engines run api server - for engine in auxiliary_engines: - engine.run_api_server.remote() + for engines in auxiliary_engines: + for engine in engines: + engine.run_api_server.remote() return rollout_engines, auxiliary_engines diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index f5f3466675..9c3cc414c7 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -114,7 +114,7 @@ def _init_runner_pool(self) -> RunnerPool: f"Number of Runners is less than number of models, set to {self.config.explorer.runner_num}" ) self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners") - return RunnerPool(self.config, self.models) + return RunnerPool(self.config, self.models, self.auxiliary_models) def _update_model_weight(self, state_dict: dict) -> None: # TODO: update model weight diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py index e58e87c124..761d95965b 100644 --- a/trinity/explorer/runner_pool.py +++ b/trinity/explorer/runner_pool.py @@ -1,10 +1,11 @@ """Runner pool for running tasks in parallel. Modified from ray.util.actor_pool.ActorPool.""" import random -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import ray from trinity.common.config import Config +from trinity.common.models.model import InferenceModel from trinity.common.workflows import Task from trinity.explorer.workflow_runner import Status, WorkflowRunner from trinity.utils.log import get_logger @@ -19,11 +20,17 @@ class RunnerPool: `config.explorer.max_timeout`. """ - def __init__(self, config: Config, models: List): + def __init__( + self, + config: Config, + models: List[InferenceModel], + auxiliary_models: Optional[List[List[InferenceModel]]] = None, + ): # actors to be used self.logger = get_logger(__name__) self.config = config self.models = models + self.auxiliary_models = auxiliary_models or [] self.timeout = config.explorer.max_timeout self.max_retry_times = config.explorer.max_retry_times @@ -44,6 +51,9 @@ def __init__(self, config: Config, models: List): # create new actors self.engine_status = [0] * config.explorer.rollout_model.engine_num + self.auxiliary_engine_status_list = [ + [0] * cfg.engine_num for cfg in config.explorer.auxiliary_models + ] self._idle_actors = list() self.actor_to_engine_index = {} self._create_actors(config.explorer.runner_num) @@ -52,7 +62,15 @@ def _create_actors(self, num: int = 1): new_actors = [] for _ in range(num): engine_index = self.engine_status.index(min(self.engine_status)) - new_actor = WorkflowRunner.remote(self.config, self.models[engine_index]) + selected_auxiliary_models = [ + models[engine_status.index(min(engine_status))] + for models, engine_status in zip( + self.auxiliary_models, self.auxiliary_engine_status_list + ) + ] + new_actor = WorkflowRunner.remote( + self.config, self.models[engine_index], selected_auxiliary_models + ) new_actors.append(new_actor) self.engine_status[engine_index] += 1 self.actor_to_engine_index[new_actor] = engine_index