diff --git a/docs/sphinx_doc/source/index.rst b/docs/sphinx_doc/source/index.rst index e25a3a321e..5604faa15d 100644 --- a/docs/sphinx_doc/source/index.rst +++ b/docs/sphinx_doc/source/index.rst @@ -18,6 +18,7 @@ Welcome to Trinity-RFT's documentation! tutorial/example_reasoning_basic.md tutorial/example_reasoning_advanced.md + tutorial/example_async_mode.md tutorial/example_multi_turn.md tutorial/example_dpo.md tutorial/example_data_functionalities.md diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index f1c82a9f25..dc8e7e676f 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -2,7 +2,9 @@ This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines. -> **Note**: Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code. +```{note} +Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code. +``` --- @@ -31,11 +33,11 @@ Before starting development, it's important to understand several core concepts: ### Step 1: Prepare Task Dataset -Each `Task` is a Python dictionary (`Dict[str, Any]`), containing various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario. +Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario. In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively. -```json +``` {"question": "1+1=", "answer": "2"} {"question": "2+2=", "answer": "4"} ... @@ -48,25 +50,45 @@ In the math problem scenario, the `Task` dataset can be a `jsonl` file, where ea The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows: ```python -from abc import ABC -from typing import List +# import some packages class Workflow(ABC): - def __init__(self, model: ModelWrapper, **kwargs): + def __init__( + self, + model: ModelWrapper, + task: Task, + auxiliary_models: Optional[List[openai.OpenAI]] = None, + ): self.model = model + self.auxiliary_models = auxiliary_models @abstractmethod def run(self) -> List[Experience]: """Run the workflow and return a list of Experiences.""" ``` -Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflows`. +Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes. 
+ +```python +# import some packages +from trinity.common.workflows.workflow import WORKFLOWS + +@WORKFLOWS.register_module("my_workflow") +class MyWorkflow(Workflow): + pass +``` #### Initialization Parameters When initializing, `Workflow` receives the following parameters: -- `model`: Provides an API call interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`). -- `kwargs`: Reads one line of data from the `Task` dataset, allowing developers to initialize internal modules such as Agent and Environment within the `Workflow` based on these parameters. +- `model`: The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`). +- `task`: An instance of `Task`, constructed from one line of the `Task` dataset. The `raw_task` field contains the raw source data as a `Dict`, which can be used to initialize the `Workflow`. +The `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k` and `top_p`. +- `auxiliary_models`: A list of auxiliary models, which will not be trained. All of them expose an OpenAI-compatible API. + +```{tip} +The `model` also provides an OpenAI-compatible API. You can enable it by setting `explorer.enable_openai_api` to `true` in your config file and then call `model.get_openai_client()` to get an `openai.OpenAI` instance.
+``` #### Example Code Below is a simple example demonstrating how to implement a math problem `Workflow`: @@ -75,10 +97,16 @@ Below is a simple example demonstrating how to implement a math problem `Workflo @WORKFLOWS.register_module("example_workflow") class ExampleWorkflow(Workflow): - def __init__(self, model: ModelWrapper, **kwargs): - super().__init__(model) - self.question = kwargs.get("question") - self.answer = kwargs.get("answer") + def __init__(self, model: ModelWrapper, task: Task, **kwargs): + super().__init__(model, **kwargs) + self.question = task.raw_task.get("question") + self.answer = task.raw_task.get("answer") + + def calculate_reward(self, response: str, truth: str) -> float: + if response == truth: + return 1.0 + else: + return 0.0 def run(self) -> List[Experience]: response = self.model.chat( @@ -87,15 +115,19 @@ class ExampleWorkflow(Workflow): "role": "user", "content": f"Question:\n{self.question}", } - ] + ], + n=self.task.rollout_args.repeat_times, + temperature=self.task.rollout_args.temperature, ) - reward: float = calculate_reward(response.response_text, self.answer) - return [Experience( - tokens=response.tokens, - prompt_length=response.prompt_length, - reward=reward, - logprobs=response.logprobs, - )] + reward: float = self.calculate_reward(response.response_text, self.answer) + return [ + Experience( + tokens=response.tokens, + prompt_length=response.prompt_length, + reward=reward, + logprobs=response.logprobs, + ) + ] ``` --- diff --git a/pyproject.toml b/pyproject.toml index e678a4b6f9..1ec92220ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "flask", "requests", "tensorboard", + "openai", ] [project.scripts] diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 090cb6f2fd..f33dc06903 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer from tests.tools import RayUnittestBase, get_template_config -from trinity.common.models import create_rollout_models +from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper from trinity.common.models.utils import ( tokenize_and_mask_messages_default, @@ -127,6 +127,7 @@ def test_generate(self): ) self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) + self.assertRaises(ValueError, self.model_wrapper.get_openai_client) class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase): @@ -139,7 +140,7 @@ def setUp(self): self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 self.config.explorer.use_v1 = False self.config.explorer.chat_template = CHAT_TEMPLATE - self.engines = create_rollout_models(self.config) + self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm") @@ -153,7 +154,7 @@ def setUp(self): self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 self.config.explorer.use_v1 = False self.config.explorer.chat_template = CHAT_TEMPLATE - self.engines = create_rollout_models(self.config) + self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -166,7 +167,7 @@ def setUp(self): self.config.explorer.tensor_parallel_size = 2 self.config.explorer.use_v1 = False self.config.explorer.chat_template = CHAT_TEMPLATE - 
self.engines = create_rollout_models(self.config) + self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -180,7 +181,7 @@ def setUp(self): self.config.buffer.explorer_input.taskset.rollout_args.repeat_times = 2 self.config.explorer.use_v1 = True self.config.explorer.chat_template = CHAT_TEMPLATE - self.engines = create_rollout_models(self.config) + self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") @@ -193,10 +194,48 @@ def setUp(self): self.config.explorer.tensor_parallel_size = 1 self.config.explorer.use_v1 = True self.config.explorer.chat_template = CHAT_TEMPLATE - self.engines = create_rollout_models(self.config) + self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") +class TestAPIServer(RayUnittestBase): + def setUp(self): + self.config = get_template_config() + self.config.model.model_path = get_model_path() + self.config.explorer.engine_type = "vllm_async" + self.config.explorer.engine_num = 1 + self.config.explorer.tensor_parallel_size = 1 + self.config.explorer.use_v1 = True + self.config.explorer.chat_template = CHAT_TEMPLATE + self.config.explorer.enable_openai_api = True + self.engines, self.auxiliary_engines = create_inference_models(self.config) + self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + + def test_api(self): + openai_client = self.model_wrapper.get_openai_client() + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is your name?"}, + ] + response = openai_client.chat.completions.create( + model=self.config.model.model_path, messages=messages, n=1 + ) + self.assertEqual(1, len(response.choices)) + self.assertTrue(len(response.choices[0].message.content) > 0) + response = openai_client.chat.completions.create( + model=self.config.model.model_path, + messages=messages, + n=2, + temperature=0.5, + logprobs=True, + top_logprobs=0, + ) + self.assertEqual(2, len(response.choices)) + self.assertTrue(response.choices[0].logprobs is not None) + self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) + self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) + + class TestTokenizer(unittest.TestCase): def test_assistant_token_mask(self): messages = [ diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 5d63fcdc63..1dc14ded69 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -21,7 +21,7 @@ @WORKFLOWS.register_module("dummy_workflow") class DummyWorkflow(Workflow): - def __init__(self, model, task): + def __init__(self, model, task, auxiliary_models): super().__init__(model, task) self.error_type = task.task_desc self.seconds = None diff --git a/trinity/common/config.py b/trinity/common/config.py index 2508081cfb..c168bb48b1 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -137,6 +137,25 @@ class ModelConfig: enable_thinking: bool = False +@dataclass +class InferenceModelConfig: + # TODO: support setting engine_num + model_path: str = "" + tensor_parallel_size: int = 1 + use_v1: bool = True + max_prompt_tokens: int = 2048 + max_response_tokens: int = 2048 + enable_thinking: bool = False + enforce_eager: bool = True + enable_prefix_caching: bool = False + 
enable_chunked_prefill: bool = False + gpu_memory_utilization: float = 0.9 + dtype: str = "bfloat16" + seed: int = 42 + chat_template: Optional[str] = None + bundle_indices: str = "" # DO NOT SET this field + + @dataclass class ClusterConfig: """Config for the cluster.""" @@ -185,10 +204,10 @@ class BufferConfig: class ExplorerConfig: """Config for explorer.""" - # inference engine type, `vllm` or `vllm_async` - engine_type: str = "vllm" + # rollout engine type, `vllm` or `vllm_async` + engine_type: str = "vllm_async" - # number of inference engines + # number of rollout engines engine_num: int = 1 # number of workflow runners. @@ -199,7 +218,8 @@ class ExplorerConfig: # for rollout tokneize chat_template: Optional[str] = None - # for vLLM + # TODO: move vllm rollout model related args into + # `explorer.rollout_model: InferenceModelConfig` tensor_parallel_size: int = 1 enable_prefix_caching: bool = False enforce_eager: bool = True @@ -210,6 +230,7 @@ class ExplorerConfig: gpu_memory_utilization: float = 0.9 enable_chunked_prefill: bool = False use_v1: bool = True + enable_openai_api: bool = False bundle_indices: str = "" # DO NOT SET this field # for workflow runner @@ -218,6 +239,9 @@ class ExplorerConfig: max_timeout: int = 900 # wait each task for 15 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout + # for other models used in the custom workflows + auxiliary_models: List[InferenceModelConfig] = field(default_factory=list) + @dataclass class TrainerConfig: @@ -453,6 +477,10 @@ def check_and_update(self) -> None: # noqa: C901 if not self.model.critic_model_path: self.model.critic_model_path = self.model.model_path + # check explorer + if self.explorer.engine_type != "vllm_async" and self.explorer.enable_openai_api: + raise ValueError("OpenAI API server only supports the `vllm_async` engine.") + # check synchronizer self.synchronizer.explorer_world_size = ( self.explorer.engine_num * self.explorer.tensor_parallel_size diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 87a2dd8ced..00faa71165 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -1,13 +1,40 @@ from collections import defaultdict -from typing import List +from typing import List, Tuple -from trinity.common.config import Config +from trinity.common.config import Config, InferenceModelConfig from trinity.common.models.model import InferenceModel +from trinity.utils.log import get_logger -def create_rollout_models( +class _BundleAllocator: + """An allocator for bundles.""" + + def __init__(self, node_bundle_map: dict[str, list]) -> None: + self.logger = get_logger(__name__) + self.node_bundle_list = [value for value in node_bundle_map.values()] + self.node_list = [key for key in node_bundle_map.keys()] + self.nid = 0 + self.bid = 0 + + def allocate(self, num: int) -> list: + # allocate num bundles from current node + if self.bid + num > len(self.node_bundle_list[self.nid]): + raise ValueError( + "Bundle allocation error, a tensor parallel group" + " is allocated across multiple nodes."
+ ) + bundle_list = self.node_bundle_list[self.nid][self.bid : self.bid + num] + self.logger.info(f"Allocate bundles {bundle_list} on node {self.node_list[self.nid]}.") + self.bid += num + if self.bid == len(self.node_bundle_list[self.nid]): + self.bid = 0 + self.nid += 1 + return bundle_list + + +def create_inference_models( config: Config, -) -> List[InferenceModel]: +) -> Tuple[List[InferenceModel], List[InferenceModel]]: """Create `engine_num` rollout models. Each model has `tensor_parallel_size` workers. @@ -23,7 +50,10 @@ def create_rollout_models( tensor_parallel_size = config.explorer.tensor_parallel_size is_multi_process = config.explorer.tensor_parallel_size > 1 - vllm_engines = [] + if config.explorer.enable_openai_api and config.explorer.engine_type != "vllm_async": + raise ValueError("OpenAI API is only supported for vllm_async engine") + + rollout_engines = [] if config.explorer.engine_type == "vllm": engine_cls = vLLMRolloutModel @@ -32,11 +62,18 @@ def create_rollout_models( else: raise ValueError(f"Unknown engine type: {config.explorer.engine_type}") - bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)] - pg = placement_group(bundles, strategy="PACK") + main_bundles = [{"GPU": 1, "CPU": 1} for _ in range(engine_num * tensor_parallel_size)] + auxiliary_bundles = [ + {"GPU": 1, "CPU": 1} + for _ in range( + sum([model.tensor_parallel_size for model in config.explorer.auxiliary_models]) + ) + ] + pg = placement_group(main_bundles + auxiliary_bundles, strategy="PACK") ray.get(pg.ready()) - vllm_engines = [] + rollout_engines = [] + auxiliary_engines = [] # to address https://github.com/ray-project/ray/issues/51117 # aggregate bundles belonging to the same node @@ -44,27 +81,62 @@ def create_rollout_models( node_bundle_map = defaultdict(list) for bundle_id, node_id in bundle_node_map.items(): node_bundle_map[node_id].append(bundle_id) + allocator = _BundleAllocator(node_bundle_map) - for node_id, bundle_ids in node_bundle_map.items(): - assert len(bundle_ids) % tensor_parallel_size == 0, ( - f"Node {node_id} has {len(bundle_ids)} bundles, " - f"which is not divisible by tensor_parallel_size({tensor_parallel_size})" + # create rollout models + for _ in range(config.explorer.engine_num): + bundles_for_engine = allocator.allocate(config.explorer.tensor_parallel_size) + model_config = InferenceModelConfig( + model_path=config.model.model_path, + tensor_parallel_size=config.explorer.tensor_parallel_size, + use_v1=config.explorer.use_v1, + max_prompt_tokens=config.model.max_prompt_tokens, + max_response_tokens=config.model.max_response_tokens, + enforce_eager=config.explorer.enforce_eager, + enable_prefix_caching=config.explorer.enable_prefix_caching, + enable_chunked_prefill=config.explorer.enable_chunked_prefill, + enable_thinking=config.model.enable_thinking, + gpu_memory_utilization=config.explorer.gpu_memory_utilization, + dtype=config.explorer.dtype, + seed=config.explorer.seed, + chat_template=config.explorer.chat_template, + bundle_indices=",".join([str(bid) for bid in bundles_for_engine]), ) - for i in range(len(bundle_ids) // tensor_parallel_size): - bundles_for_engine = bundle_ids[ - i * tensor_parallel_size : (i + 1) * tensor_parallel_size - ] - config.explorer.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) - vllm_engines.append( - engine_cls.options( - num_cpus=0, - num_gpus=0 if is_multi_process else 1, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - 
placement_group_bundle_index=bundles_for_engine[0], - ), - ).remote( - config=config, - ) + rollout_engines.append( + ray.remote(engine_cls) + .options( + num_cpus=0, + num_gpus=0 if is_multi_process else 1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundles_for_engine[0], + ), ) - return vllm_engines + .remote( + config=model_config, + ) + ) + if config.explorer.enable_openai_api: + for engine in rollout_engines: + engine.run_api_server.remote() + + # create auxiliary models + for model_config in config.explorer.auxiliary_models: + bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) + auxiliary_engines.append( + ray.remote(vLLMAysncRolloutModel) + .options( + num_cpus=0, + num_gpus=0 if is_multi_process else 1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundles_for_engine[0], + ), + ) + .remote(config=model_config) + ) + # all auxiliary engines run api server + for engine in auxiliary_engines: + engine.run_api_server.remote() + + return rollout_engines, auxiliary_engines diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 8225aa09fa..baf5ffdd8c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -1,13 +1,16 @@ # -*- coding: utf-8 -*- """Base Model Class""" import socket +import time from abc import ABC, abstractmethod from typing import Any, List, Tuple +import openai import ray from torch import Tensor from trinity.common.experience import Experience +from trinity.utils.log import get_logger class InferenceModel(ABC): @@ -49,7 +52,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex def get_ckp_version(self) -> int: """Get the checkpoint version.""" - def get_address(self) -> Tuple[str, int]: + def get_available_address(self) -> Tuple[str, int]: """Get the address of the actor.""" address = ray.util.get_node_ip_address() with socket.socket() as s: @@ -65,6 +68,8 @@ class ModelWrapper: def __init__(self, model: Any, model_type: str = "vllm"): self.model = model self.use_async = model_type == "vllm_async" + self.openai_client: openai.OpenAI = None + self.logger = get_logger(__name__) def generate(self, prompts: List[str], **kwargs) -> List[Experience]: if self.use_async: @@ -96,3 +101,30 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience: def get_ckp_version(self) -> int: return ray.get(self.model.get_ckp_version.remote()) + + def get_openai_client(self) -> openai.OpenAI: + if self.openai_client is not None: + return self.openai_client + if not ray.get(self.model.has_api_server.remote()): + raise ValueError( + "OpenAI API server is not running on current model." + "Please set `explorer.enable_openai_api` to `True`." + ) + api_address = None + while True: + api_address = ray.get(self.model.api_server_ready.remote()) + if api_address is not None: + break + else: + self.logger.info("Waiting for OpenAI API server to be ready...") + time.sleep(5) + if api_address is None: + raise RuntimeError( + "Failed to connect to the API server. Please check the API server is running." 
+ ) + self.logger.info(f"Successfully connect to API server at {api_address}") + self.openai_client = openai.OpenAI( + base_url=api_address, + api_key="EMPTY", + ) + return self.openai_client diff --git a/trinity/common/models/openai_api.py b/trinity/common/models/openai_api.py new file mode 100644 index 0000000000..c26b0ca54b --- /dev/null +++ b/trinity/common/models/openai_api.py @@ -0,0 +1,79 @@ +"""OpenAI API server related tools. + +Modified from vllm/entrypoints/openai/api_server.py +""" +import asyncio +import functools + +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + init_app_state, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils import FlexibleArgumentParser, set_ulimit + + +async def run_server_in_ray(args, engine_client): + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host, args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + app = build_app(args) + + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) + + await patch_and_serve_http(app, sock, args) + + # # NB: Await server shutdown only after the backend context is exited + # try: + # await shutdown_task + # finally: + # sock.close() + + +def dummy_add_signal_handler(self, *args, **kwargs): + # DO NOTHING HERE + pass + + +async def patch_and_serve_http(app, sock, args): + """Patch the add_signal_handler method and serve the app.""" + loop = asyncio.get_event_loop() + original_add_signal_handler = loop.add_signal_handler + loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) + + try: + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level="info", + access_log=True, + timeout_keep_alive=10, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + ) + await shutdown_task + finally: + loop.add_signal_handler = original_add_signal_handler + sock.close() + + +async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str): + parser = FlexibleArgumentParser(description="Run the OpenAI API server.") + args = make_arg_parser(parser) + args = parser.parse_args(["--host", str(host), "--port", str(port), "--model", model_path]) + print(args) + await run_server_in_ray(args, async_llm) diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index fc39c3c303..ae5c4db9c1 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -3,18 +3,16 @@ Modified from Ray python/ray/llm/_internal/batch/stages/vllm_engine_stage.py """ -import asyncio import os import re -from contextlib import nullcontext from typing import Any, Dict, List, Optional -import ray +import aiohttp import torch import vllm from vllm.sampling_params import RequestOutputKind -from trinity.common.config import Config +from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel from trinity.common.models.utils import 
( @@ -28,7 +26,6 @@ # TODO: merge into vLLMRolloutModel # TODO: remove V0 when V1 is stable -@ray.remote class vLLMAysncRolloutModel(InferenceModel): """Wrapper around the vLLM engine to handle async requests. @@ -39,59 +36,56 @@ class vLLMAysncRolloutModel(InferenceModel): def __init__( self, - config: Config, - **kwargs, + config: InferenceModelConfig, ) -> None: self.logger = get_logger(__name__) self.config = config - self.use_v1 = config.explorer.use_v1 - if config.explorer.tensor_parallel_size != 1: + self.use_v1 = config.use_v1 + if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.explorer.bundle_indices + os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices if not vllm.envs.is_set("VLLM_USE_V1"): - self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine") - os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1)) - if config.explorer.use_v1: - os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.explorer.use_v1)) + self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") + os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) + if config.use_v1: + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" self.default_sampling_params = vllm.SamplingParams( n=1, temperature=0.0, - max_tokens=config.model.max_response_tokens, + max_tokens=config.max_response_tokens, min_tokens=1, - truncate_prompt_tokens=config.model.max_prompt_tokens, + truncate_prompt_tokens=config.max_prompt_tokens, skip_special_tokens=True, include_stop_str_in_output=False, output_kind=RequestOutputKind.FINAL_ONLY, logprobs=0, ) - self.enable_thinking = config.model.enable_thinking + self.enable_thinking = config.enable_thinking self.request_id = 0 engine_args = vllm.AsyncEngineArgs( - model=config.model.model_path, - enforce_eager=config.explorer.enforce_eager, + model=config.model_path, + enforce_eager=config.enforce_eager, worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension", - tensor_parallel_size=config.explorer.tensor_parallel_size, - seed=config.explorer.seed, - distributed_executor_backend=( - "uni" if config.explorer.tensor_parallel_size == 1 else "ray" - ), - max_model_len=config.model.max_prompt_tokens + config.model.max_response_tokens, - enable_prefix_caching=config.explorer.enable_prefix_caching, - dtype=config.explorer.dtype, + tensor_parallel_size=config.tensor_parallel_size, + seed=config.seed, + distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), + max_model_len=config.max_prompt_tokens + config.max_response_tokens, + enable_prefix_caching=config.enable_prefix_caching, + dtype=config.dtype, trust_remote_code=True, task="generate", disable_log_requests=True, - gpu_memory_utilization=config.explorer.gpu_memory_utilization, - enable_chunked_prefill=config.explorer.enable_chunked_prefill, + gpu_memory_utilization=config.gpu_memory_utilization, + enable_chunked_prefill=config.enable_chunked_prefill, # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage ) self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.tokenizer = None self.chat_template = None - if self.config.explorer.chat_template: - self.chat_template = self.config.explorer.chat_template + if self.config.chat_template: + self.chat_template = self.config.chat_template if self.chat_template is None or 
not re.search( r"\{\%-?\s*generation\s*-?\%\}", self.chat_template ): @@ -103,14 +97,9 @@ def __init__( self.action_mask_method = tokenize_and_mask_messages_default else: self.action_mask_method = tokenize_and_mask_messages_hf - # The performance gets really bad if there are too many requests in the pending queue. - # We work around it with semaphore to limit the number of concurrent requests in the engine. - self.max_pending_requests = config.explorer.max_pending_requests - if self.max_pending_requests > 0: - self.semaphore = asyncio.Semaphore(self.max_pending_requests) - else: - self.semaphore = nullcontext() self.ckp_version = 0 # TODO: resume the value from the checkpoint + self.api_server_host = None + self.api_server_port = None async def chat_async(self, messages: List[Dict], **kwargs) -> List[Experience]: """Chat with the model with a list of messages in async. @@ -153,8 +142,7 @@ async def generate_async(self, prompt: str, **kwargs) -> List[Experience]: Returns: A list of experiences. """ - async with self.semaphore: - output = await self._generate_internal(prompt=prompt, **kwargs) + output = await self._generate_internal(prompt=prompt, **kwargs) experiences = [ Experience( tokens=torch.cat( @@ -189,13 +177,12 @@ async def generate_async(self, prompt: str, **kwargs) -> List[Experience]: async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor: """Calculate the logprobs of the given tokens in async.""" - async with self.semaphore: - output = await self._generate_internal( - prompt={"prompt_token_ids": token_ids}, - n=1, - max_tokens=1, - prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token - ) + output = await self._generate_internal( + prompt={"prompt_token_ids": token_ids}, + n=1, + max_tokens=1, + prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token + ) return torch.tensor( [0] + [ @@ -310,6 +297,46 @@ async def init_process_group( async def update_weight(self, name, dtype, shape, empty_cache=False): return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) + async def run_api_server(self): + """Run the OpenAI API server in a Ray actor. + + Note: + Do not use `ray.get()` on this method. + This method will run forever until the server is shut down. + """ + if not (self.api_server_host is None or self.api_server_port is None): + raise RuntimeError("API server is already running.") + from trinity.common.models.openai_api import run_api_server_in_ray_actor + + self.api_server_host, self.api_server_port = self.get_available_address() + await run_api_server_in_ray_actor( + self.async_llm, self.api_server_host, self.api_server_port, self.config.model_path + ) + + async def has_api_server(self) -> bool: + return self.api_server_host is not None and self.api_server_port is not None + + async def api_server_ready(self) -> Optional[str]: + """Check if the OpenAI API server is ready. + + Returns: + str: The URL of the OpenAI API server. 
+ """ + if not await self.has_api_server(): + return None + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://{self.api_server_host}:{self.api_server_port}/health" + ) as response: + if response.status == 200: + return f"http://{self.api_server_host}:{self.api_server_port}/v1" + else: + return None + except Exception as e: + self.logger.error(e) + return None + async def reset_prefix_cache(self) -> None: await self.async_llm.reset_prefix_cache() diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 1e3efe6d22..32ab98fe8a 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -10,13 +10,12 @@ import threading from typing import List -import ray import torch import vllm from vllm import LLM from vllm.sampling_params import SamplingParams -from trinity.common.config import Config +from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel from trinity.common.models.utils import ( @@ -26,57 +25,53 @@ from trinity.utils.log import get_logger -@ray.remote class vLLMRolloutModel(InferenceModel): """Actor for vLLM.""" - def __init__(self, config: Config, **kwargs): + def __init__(self, config: InferenceModelConfig): self.logger = get_logger(__name__) self.config = config - if config.explorer.tensor_parallel_size != 1: + if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.explorer.bundle_indices + os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices if not vllm.envs.is_set("VLLM_USE_V1"): - self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine") - os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1)) - if config.explorer.use_v1: - os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.explorer.use_v1)) + self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") + os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) + if config.use_v1: + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" self.default_sampling_params = SamplingParams( n=1, temperature=0.0, - max_tokens=config.model.max_response_tokens, + max_tokens=config.max_response_tokens, min_tokens=1, - truncate_prompt_tokens=config.model.max_prompt_tokens, + truncate_prompt_tokens=config.max_prompt_tokens, skip_special_tokens=True, include_stop_str_in_output=False, logprobs=0, ) self.llm = LLM( # TODO: check checkpoint path - model=config.model.model_path, - enforce_eager=config.explorer.enforce_eager, + model=config.model_path, + enforce_eager=config.enforce_eager, worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension", - tensor_parallel_size=config.explorer.tensor_parallel_size, - seed=config.explorer.seed, - distributed_executor_backend=( - "uni" if config.explorer.tensor_parallel_size == 1 else "ray" - ), - max_model_len=config.model.max_prompt_tokens + config.model.max_response_tokens, - enable_prefix_caching=config.explorer.enable_prefix_caching, - dtype=config.explorer.dtype, + tensor_parallel_size=config.tensor_parallel_size, + seed=config.seed, + distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), + max_model_len=config.max_prompt_tokens + config.max_response_tokens, + enable_prefix_caching=config.enable_prefix_caching, + 
dtype=config.dtype, trust_remote_code=True, - gpu_memory_utilization=config.explorer.gpu_memory_utilization, - enable_chunked_prefill=config.explorer.enable_chunked_prefill, + gpu_memory_utilization=config.gpu_memory_utilization, + enable_chunked_prefill=config.enable_chunked_prefill, # max_num_batched_tokens=256, - **kwargs, ) self.tokenizer = self.llm.get_tokenizer() self.chat_template = self.tokenizer.get_chat_template() - self.enable_thinking = config.model.enable_thinking - if self.config.explorer.chat_template: - self.chat_template = self.config.explorer.chat_template + self.enable_thinking = config.enable_thinking + if self.config.chat_template: + self.chat_template = self.config.chat_template if not re.search(r"\{\%-?\s*generation\s*-?\%\}", self.chat_template): self.logger.warning( "The provided chat template does not support `return_assitant_tokens_mask`. " @@ -170,12 +165,11 @@ def generate(self, prompts: List[str], **kwargs) -> List: ] """ with self.lock: - outputs = self.llm.generate( - prompts, self._create_sampling_params(**kwargs), use_tqdm=False - ) + sampling_params = self._create_sampling_params(**kwargs) + outputs = self.llm.generate(prompts, sampling_params, use_tqdm=False) experiences = [] for output in outputs: - for i in range(self.config.buffer.explorer_input.taskset.rollout_args.repeat_times): + for i in range(sampling_params.n): experiences.append( Experience( tokens=torch.cat( @@ -274,6 +268,9 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience: action_mask=action_mask, ) + def has_api_server(self) -> bool: + return False + def sync_model(self, update_weight_args_list) -> bool: """Sync model weights to vLLM.""" with self.lock: diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index f0f323918b..b988161723 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -270,7 +270,9 @@ class veRLConfig: def synchronize_config(self, config: Config) -> None: """Synchronize config.""" - rollout_gpu_num = config.explorer.tensor_parallel_size * config.explorer.engine_num + rollout_gpu_num = config.explorer.tensor_parallel_size * config.explorer.engine_num + sum( + [model.tensor_parallel_size for model in config.explorer.auxiliary_models] + ) rollout_node_num = rollout_gpu_num // config.cluster.gpu_per_node self.trainer.nnodes = config.cluster.node_num - rollout_node_num self.actor_rollout_ref.model.path = config.model.model_path diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 9f8b389858..9b3b0d79d4 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import List +from typing import List, Optional from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper @@ -100,6 +100,7 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List] = None, ): super().__init__( model=model, diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index b3669f01d0..60bc6c4d81 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import json -from typing import List +from typing import List, Optional from 
trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper @@ -63,6 +63,7 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List] = None, ): super().__init__( model=model, diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 55035fd7b4..6b961116d0 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import List +from typing import List, Optional from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper @@ -185,6 +185,7 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List] = None, ): super().__init__( model=model, diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 905fe2e5b8..d44d9e9813 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -7,6 +7,7 @@ from dataclasses import asdict, dataclass, field from typing import Any, List, Optional, Type, Union +import openai import torch from trinity.common.config import FormatConfig, GenerationConfig @@ -33,7 +34,9 @@ class Task: reward_fn: Optional[Type[RewardFn]] = None raw_task: Optional[dict] = None # The raw data sample - def to_workflow(self, model: Any) -> Workflow: + def to_workflow( + self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None + ) -> Workflow: """Convert the task to a workflow. Args: @@ -45,6 +48,7 @@ def to_workflow(self, model: Any) -> Workflow: return self.workflow( model=model, task=self, + auxiliary_models=auxiliary_models, ) @property @@ -68,8 +72,10 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List[openai.OpenAI]] = None, ): self.model = model + self.auxiliary_models = auxiliary_models @abstractmethod def run(self) -> List[Experience]: @@ -85,10 +91,12 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List[openai.OpenAI]] = None, ): super().__init__( model=model, task=task, + auxiliary_models=auxiliary_models, ) @abstractmethod @@ -133,6 +141,7 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List[openai.OpenAI]] = None, ): super().__init__( model=model, @@ -198,6 +207,7 @@ def __init__( self, model: ModelWrapper, task: Task, + auxiliary_models: Optional[List[openai.OpenAI]] = None, ): if task.reward_fn is None: task.reward_fn = MathRewardFn diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 54ad581562..f2b2490e1d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -12,7 +12,7 @@ from trinity.buffer.buffer import get_buffer_reader from trinity.common.config import Config from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod -from trinity.common.models import create_rollout_models +from trinity.common.models import create_inference_models from trinity.common.models.utils import ( get_checkpoint_dir_with_step_num, load_state_dict, @@ -33,7 +33,7 @@ def __init__(self, config: Config): explorer_meta = self.cache.load_explorer() self.step_num = explorer_meta.get("latest_iteration", 0) self.config = config - self.models = create_rollout_models(config) + self.models, self.auxiliary_models = create_inference_models(config) if self.config.mode != "bench": 
self.experience_buffer = get_buffer_writer( self.config.buffer.explorer_output, # type: ignore @@ -147,7 +147,7 @@ def _nccl_weights_update(self): def prepare(self) -> None: """Preparation before running.""" if self.use_checkpoint_weights_update: - master_address, master_port = ray.get(self.models[0].get_address.remote()) + master_address, master_port = ray.get(self.models[0].get_available_address.remote()) self.setup_weight_sync_group(master_address, master_port) @ray.method(concurrency_group="get_weight") diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 87e80aaf9b..e60821347a 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -30,14 +30,30 @@ class Status: class WorkflowRunner: """A Ray remote actor to run the workflow and put the returned experiences into the buffer.""" - def __init__(self, config: Config, model: InferenceModel) -> None: + def __init__( + self, + config: Config, + model: InferenceModel, + auxiliary_models: Optional[List[InferenceModel]] = None, + ) -> None: self.config = config self.experience_buffer = get_buffer_writer( self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) self.model = model - self.model_wrapper = ModelWrapper(model, config.explorer.engine_type) + self.model_wrapper = ModelWrapper( + model, + config.explorer.engine_type, + ) + self.auxiliary_models = [] + if auxiliary_models is not None: + for model in auxiliary_models: + api_client = ModelWrapper( + model, + "vllm_async", + ).get_openai_client() + self.auxiliary_models.append(api_client) self.logger = get_logger(__name__) def is_alive(self): @@ -47,7 +63,7 @@ def _run_task(self, task: Task) -> List[Experience]: """Init workflow from the task and run it.""" if task.workflow is None: raise ValueError("Workflow is not set in the task.") - workflow = task.to_workflow(self.model_wrapper) + workflow = task.to_workflow(self.model_wrapper, self.auxiliary_models) return workflow.run() def run_task(self, task: Task) -> Status:
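To illustrate how the pieces introduced above fit together, the sketch below shows a custom workflow that consumes the `auxiliary_models` clients injected by `WorkflowRunner`. It is a minimal, hypothetical example: the workflow name, judge prompt, reward rule, the model-name lookup via `/v1/models`, and the assumption that `ModelWrapper.chat` returns a list of `Experience` objects (one per sample) are illustrative and not taken verbatim from this patch.

```python
# A minimal, hypothetical workflow using the new `auxiliary_models` argument.
# Assumptions: `ModelWrapper.chat` returns one `Experience` per sample, and the
# auxiliary vLLM server exposes its served model via the `/v1/models` endpoint.
from typing import List, Optional

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow


@WORKFLOWS.register_module("judged_math_workflow")
class JudgedMathWorkflow(Workflow):
    def __init__(
        self,
        model: ModelWrapper,
        task: Task,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(model=model, task=task, auxiliary_models=auxiliary_models)
        self.task = task
        self.question = task.raw_task.get("question")
        self.answer = task.raw_task.get("answer")

    def run(self) -> List[Experience]:
        # Sample from the model being trained, using the task's rollout args.
        responses = self.model.chat(
            [{"role": "user", "content": f"Question:\n{self.question}"}],
            n=self.task.rollout_args.repeat_times,
            temperature=self.task.rollout_args.temperature,
        )
        # The first auxiliary model acts as a judge via its OpenAI-compatible API.
        judge = self.auxiliary_models[0]
        judge_model_name = judge.models.list().data[0].id
        experiences = []
        for response in responses:
            verdict = judge.chat.completions.create(
                model=judge_model_name,
                messages=[
                    {
                        "role": "user",
                        "content": (
                            f"Reference answer: {self.answer}\n"
                            f"Candidate answer: {response.response_text}\n"
                            "Does the candidate match the reference? Reply YES or NO."
                        ),
                    }
                ],
            )
            reward = 1.0 if "YES" in verdict.choices[0].message.content else 0.0
            experiences.append(
                Experience(
                    tokens=response.tokens,
                    prompt_length=response.prompt_length,
                    reward=reward,
                    logprobs=response.logprobs,
                )
            )
        return experiences
```

At runtime, `WorkflowRunner` builds one `openai.OpenAI` client per entry of `explorer.auxiliary_models` and passes the list into `Task.to_workflow`, so a workflow like this needs no extra wiring beyond the config.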