diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f34f1dffd4..1c8ba2e36e 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -321,9 +321,10 @@ explorer: max_retry_times: 2 env_vars: {} rollout_model: - engine_type: vllm_async + engine_type: vllm engine_num: 1 tensor_parallel_size: 1 + enable_history: False auxiliary_models: - model_path: /PATH/TO/MODEL tensor_parallel_size: 1 @@ -336,9 +337,10 @@ explorer: - `max_timeout`: Maximum time (in seconds) for a workflow to complete. - `max_retry_times`: Maximum number of retries for a workflow. - `env_vars`: Environment variables to be set for every workflow runners. -- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`. +- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async` and `vllm` are supported; they are equivalent, and both use the asynchronous engine. In subsequent versions, only `vllm` may be retained for simplicity. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. +- `rollout_model.enable_history`: Whether to enable model call history recording. If set to `True`, the model wrapper automatically records the experiences returned by model calls. Please periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues. Default is `False`. - `auxiliary_models`: Additional models used for custom workflows. - `eval_interval`: Interval (in steps) for evaluating the model. - `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting. diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index fb75d084b1..a7c92bef61 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -122,6 +122,7 @@ During initialization, `Workflow` receives the following parameters: ```{tip} You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow. +The `model` field to pass when calling the OpenAI API can be obtained via `openai_client.models.list().data[0].id`. ``` Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
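The snippet below is a minimal sketch of how the two documented additions fit together. It assumes a `ModelWrapper` created with both `enable_openai_api` and `enable_history` turned on (for example, the `model` object passed to a workflow); the helper name `query_and_collect` is illustrative and not part of the codebase.

```python
from trinity.common.models.model import ModelWrapper


def query_and_collect(model: ModelWrapper):
    """Query the rollout model via its OpenAI endpoint, then drain the recorded history."""
    openai_client = model.get_openai_client()
    # The served model name is not the local checkpoint path; ask the server for it.
    model_id = openai_client.models.list().data[0].id
    response = openai_client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": "Hello!"}],
        n=1,
    )
    print(response.choices[0].message.content)
    # With enable_history=True, every call is also recorded as Experience objects.
    # Extract them periodically so the in-memory history does not grow unbounded.
    return model.extract_experience_from_history()  # clears the history by default
```

This mirrors the flow exercised by `TestAPIServer.test_api` in `tests/common/vllm_test.py` below.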
diff --git a/pyproject.toml b/pyproject.toml index efcc8a7aab..4fa8d764df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "trinity-rft" -version = "0.2.0" +version = "0.2.1.dev0" authors = [ {name="Trinity-RFT Team", email="trinity-rft@outlook.com"}, ] diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 0146eb075c..71ccf32b7f 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -2,9 +2,10 @@ import unittest import torch +from parameterized import parameterized_class from transformers import AutoTokenizer -from tests.tools import RayUnittestBase, get_template_config +from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper from trinity.common.models.utils import ( @@ -82,12 +83,59 @@ def get_model_path() -> str: """ -class BaseTestModelWrapper: - def test_generate(self): +@parameterized_class( + ("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"), + [ + (1, 2, False, 2, True, False), + (2, 2, False, 1, False, True), + (2, 2, True, 2, True, False), + (1, 2, True, 1, False, True), + (2, 1, True, 3, True, True), + ], +) +class ModelWrapperTest(RayUnittestBaseAysnc): + def setUp(self): + # configure the model + self.config = get_template_config() + self.config.mode = "explore" + self.config.model.model_path = get_model_path() + self.config.explorer.rollout_model.engine_num = self.engine_num + self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size + self.config.explorer.rollout_model.use_v1 = self.use_v1 + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.algorithm.repeat_times = self.repeat_times + self.config.explorer.rollout_model.enable_history = self.enable_history + self.config.check_and_update() + self.engines, self.auxiliary_engines = create_inference_models(self.config) + self.model_wrapper = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=self.enable_history + ) + + async def test_generate( + self, + ): prompts = ["Hello, world!", "Hello, my name is"] n = self.config.algorithm.repeat_times - results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) - self.assertEqual(len(results), len(prompts) * n) + if self.use_async: + generate_results = await self.model_wrapper.generate_async( + prompts, n=n, temperature=1.0 + ) + else: + generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0) + self.assertEqual(len(generate_results), len(prompts) * n) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history( + clear_history=False + ) + self.assertEqual(len(history_experiences), len(generate_results)) + for exp, history_exp in zip(generate_results, history_experiences): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) + else: + with self.assertRaises(ValueError): + self.model_wrapper.extract_experience_from_history(clear_history=False) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like today?"}, @@ -97,15 +145,32 @@ def 
test_generate(self): }, {"role": "user", "content": "OK, thanks!"}, ] - results = self.model_wrapper.chat(messages, n=n, temperature=1.0) + if self.use_async: + results = await self.model_wrapper.chat_async(messages, n=n, temperature=1.0) + else: + results = self.model_wrapper.chat(messages, n=n, temperature=1.0) self.assertEqual(len(results), n) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(history_experiences) - len(generate_results), len(results)) + for exp, history_exp in zip(results, history_experiences[len(generate_results) :]): + self.assertEqual(exp.response_text, history_exp.response_text) + self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist()) + self.assertEqual(exp.prompt_length, history_exp.prompt_length) + self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) for result in results: input_logprobs = result.logprobs[: result.prompt_length] output_logprobs = result.logprobs[result.prompt_length :] self.assertTrue(torch.all(input_logprobs == 0)) self.assertTrue(torch.any(output_logprobs != 0)) - logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) + if self.use_async: + logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) + else: + logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) messages.append( { "role": "assistant", @@ -128,84 +193,9 @@ def test_generate(self): self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) self.assertRaises(ValueError, self.model_wrapper.get_openai_client) - - -class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm" - self.config.explorer.rollout_model.tensor_parallel_size = 1 - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.use_v1 = False - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = 2 - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm") - - -class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 1 - self.config.explorer.rollout_model.use_v1 = False - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = 2 - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") - - -class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = 
get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 2 - self.config.explorer.rollout_model.use_v1 = False - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") - - -class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 2 - self.config.explorer.rollout_model.use_v1 = True - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.algorithm.repeat_times = 2 - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") - - -class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase): - def setUp(self): - self.config = get_template_config() - self.config.mode = "explore" - self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" - self.config.explorer.rollout_model.engine_num = 2 - self.config.explorer.rollout_model.tensor_parallel_size = 1 - self.config.explorer.rollout_model.use_v1 = True - self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE - self.config.check_and_update() - self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + if self.config.explorer.rollout_model.enable_history: + history_experiences = self.model_wrapper.extract_experience_from_history() + self.assertTrue(len(history_experiences) == 0) class TestAPIServer(RayUnittestBase): @@ -221,7 +211,12 @@ def setUp(self): self.config.explorer.rollout_model.enable_openai_api = True self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async") + self.model_wrapper = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=True + ) + self.model_wrapper_no_history = ModelWrapper( + self.engines[0], model_type="vllm_async", enable_history=False + ) def test_api(self): openai_client = self.model_wrapper.get_openai_client() @@ -229,13 +224,12 @@ def test_api(self): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is your name?"}, ] - response = openai_client.chat.completions.create( - model=self.config.model.model_path, messages=messages, n=1 - ) + model_id = openai_client.models.list().data[0].id + response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1) self.assertEqual(1, len(response.choices)) self.assertTrue(len(response.choices[0].message.content) > 0) response = openai_client.chat.completions.create( - model=self.config.model.model_path, + model=model_id, messages=messages, n=2, temperature=0.5, @@ -246,6 +240,32 @@ def test_api(self): 
self.assertTrue(response.choices[0].logprobs is not None) self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) + self.assertTrue(hasattr(response, "prompt_token_ids")) + self.assertTrue(len(response.prompt_token_ids) > 0) + self.assertTrue(hasattr(response.choices[0], "token_ids")) + self.assertTrue(len(response.choices[0].token_ids) > 0) + exps = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(exps), 3) + response = openai_client.chat.completions.create( + model=model_id, + messages=messages, + n=4, + temperature=0.5, + logprobs=True, + top_logprobs=0, + ) + exps = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(exps), 4) + self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0) + response = self.model_wrapper_no_history.get_openai_client().chat.completions.create( + model=model_id, messages=messages, n=2 + ) + self.assertEqual(2, len(response.choices)) + self.assertTrue(hasattr(response.choices[0], "token_ids")) + self.assertTrue(len(response.choices[0].token_ids) > 0) + with self.assertRaises(ValueError): + self.model_wrapper_no_history.extract_experience_from_history() + self.assertEqual(len(self.model_wrapper_no_history.history), 0) class TestTokenizer(unittest.TestCase): diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 20ecb179c9..d25c394bff 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1,7 +1,7 @@ import asyncio import time import unittest -from typing import List, Tuple +from typing import List import ray import torch @@ -98,8 +98,8 @@ def init_process_group( def has_api_server(self) -> bool: return True - def api_server_ready(self) -> Tuple[str, str]: - return "http://localhosts:12345", "placeholder" + def api_server_ready(self) -> str: + return "http://localhosts:12345" def generate_tasks( diff --git a/trinity/__init__.py b/trinity/__init__.py index dc3b8ca098..63f1db4fdc 100644 --- a/trinity/__init__.py +++ b/trinity/__init__.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """Trinity-RFT (Reinforcement Fine-Tuning)""" -__version__ = "0.2.0" +__version__ = "0.2.1.dev0" diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index c6858931b1..80a4af7d49 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -4,7 +4,6 @@ import numpy as np import torch -from verl.trainer.ppo.ray_trainer import DataProto from trinity.algorithm.sample_strategy.sample_strategy import ( SAMPLE_STRATEGY, @@ -85,7 +84,9 @@ def default_args(cls) -> Dict: } -def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): + from verl.trainer.ppo.ray_trainer import DataProto + attention_mask = experiences.attention_masks cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() diff --git a/trinity/common/config.py b/trinity/common/config.py index 1e0bcc5e9d..3a8bcd7f91 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -203,6 +203,9 @@ class InferenceModelConfig: # For Qwen3 enable_thinking: bool = False + # For history recording + enable_history: bool = False + # For OpenAI API enable_openai_api: bool = False @@ -310,8 +313,6 @@ class 
ExplorerConfig: name: str = EXPLORER_NAME # for workflow runner # number of workflow runners. - # For sync engine (vllm), it should be `1`. - # For async engine (vllm_async), it could be a large number. runner_per_model: int = 8 # number of runners per each rollout model max_timeout: int = 1800 # wait each task for 30 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout @@ -719,11 +720,6 @@ def check_and_update(self) -> None: # noqa: C901 self.model.critic_model_path = self.model.model_path # check explorer - if ( - self.explorer.rollout_model.engine_type != "vllm_async" - and self.explorer.rollout_model.enable_openai_api - ): - raise ValueError("OpenAI API server only support `vllm_async` engine.") if self.explorer.rollout_model.max_prompt_tokens is None: self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens if self.explorer.rollout_model.max_response_tokens is None: diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index f9d092807c..fb85590b06 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -43,24 +43,15 @@ def create_inference_models( from ray.util.placement_group import placement_group, placement_group_table from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - from trinity.common.models.vllm_async_model import vLLMAysncRolloutModel from trinity.common.models.vllm_model import vLLMRolloutModel + logger = get_logger(__name__) engine_num = config.explorer.rollout_model.engine_num tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size - if ( - config.explorer.rollout_model.enable_openai_api - and config.explorer.rollout_model.engine_type != "vllm_async" - ): - raise ValueError("OpenAI API is only supported for vllm_async engine") - rollout_engines = [] - - if config.explorer.rollout_model.engine_type == "vllm": + if config.explorer.rollout_model.engine_type.startswith("vllm"): engine_cls = vLLMRolloutModel - elif config.explorer.rollout_model.engine_type == "vllm_async": - engine_cls = vLLMAysncRolloutModel else: raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}") @@ -115,17 +106,21 @@ def create_inference_models( if config.explorer.rollout_model.enable_openai_api: for engine in rollout_engines: engine.run_api_server.remote() - + if config.explorer.rollout_model.enable_history: + logger.info( + "Model History recording is enabled. Please periodically extract " + "history via `extract_experience_from_history` to avoid out-of-memory issues." + ) # create auxiliary models for model_config in config.explorer.auxiliary_models: engines = [] for _ in range(model_config.engine_num): bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) model_config.enable_openai_api = True - model_config.engine_type = "vllm_async" + model_config.engine_type = "vllm" model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) engines.append( - ray.remote(vLLMAysncRolloutModel) + ray.remote(vLLMRolloutModel) .options( num_cpus=0, num_gpus=0 if model_config.tensor_parallel_size > 1 else 1, diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/api/vllm_patch.py new file mode 100644 index 0000000000..438636e35e --- /dev/null +++ b/trinity/common/models/api/vllm_patch.py @@ -0,0 +1,330 @@ +"""Patch for vllm OpenAI API server. + +1. Mocks the `add_signal_handler` method to do nothing. +2. 
Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. +""" +import asyncio +import functools +import json +import time +from typing import Optional, Union + +from pydantic import Field, TypeAdapter +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + init_app_state, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + ErrorResponse, + FunctionCall, + FunctionDefinition, + PromptTokenUsageInfo, + ToolCall, + UsageInfo, +) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_engine import clamp_prompt_logprobs +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall +from vllm.outputs import RequestOutput +from vllm.transformers_utils.tokenizer import MistralTokenizer +from vllm.utils import FlexibleArgumentParser, set_ulimit + +from trinity.utils.log import get_logger + + +class PatchedChatCompletionResponseChoice(ChatCompletionResponseChoice): + token_ids: list[int] = Field(default_factory=list) + + +class PatchedChatCompletionResponse(ChatCompletionResponse): + prompt_token_ids: list[int] = Field(default_factory=list) + choices: list[PatchedChatCompletionResponseChoice] = list[ChatCompletionResponseChoice] + + +# TODO: add patch to stream generator +async def chat_completion_full_generator( # noqa C901 + self, + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, +) -> Union[ErrorResponse, ChatCompletionResponse]: + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + logger = get_logger(__name__) + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + auto_tools_called = False + + if self.reasoning_parser: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + # If the reasoning parser is enabled, + # tool calls are extracted exclusively from the content. 
+ reasoning_content, content = reasoning_parser.extract_reasoning_content( + output.text, request=request + ) + else: + reasoning_content = None + content = output.text + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools or not self.tool_parser) and ( + not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) + and request.tool_choice != "required" + ): + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + # if the request uses tools and specified a tool choice + elif ( + request.tool_choice and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam + ): + tool_call_class = ( + MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall + ) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content="", + tool_calls=[ + tool_call_class( + function=FunctionCall( + name=request.tool_choice.function.name, arguments=content + ) + ) + ], + ) + + elif request.tool_choice and request.tool_choice == "required": + tool_call_class = ( + MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall + ) + + # the fields of FunctionDefinition are a superset of the + # tool call outputs and can be used for parsing + assert content is not None + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content) + message = ChatMessage( + role=role, + content="", + tool_calls=[ + tool_call_class( + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, ensure_ascii=False), + ) + ) + for tool_call in tool_calls + ], + ) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + # handle when there are tools and tool choice is auto + elif ( + request.tools + and (request.tool_choice == "auto" or request.tool_choice is None) + and self.enable_auto_tools + and self.tool_parser + ): + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + return self.create_error_response(str(e)) + + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", request=request + ) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls, + ) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion." 
+ ) + message = ChatMessage(role=role, reasoning_content=reasoning_content, content=content) + + choice_data = PatchedChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" + if auto_tools_called + else output.finish_reason + if output.finish_reason + else "stop", + stop_reason=output.stop_reason, + token_ids=output.token_ids, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[-1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list): + last_msg_content = "\n".join(msg["text"] for msg in last_msg_content) + + for choice in choices: + full_message = last_msg_content + (choice.message.content or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens + ) + + request_metadata.final_usage_info = usage + + response = PatchedChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, + prompt_token_ids=final_res.prompt_token_ids, + ) + + return response + + +async def run_server_in_ray(args, engine_client): + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host, args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + app = build_app(args) + + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) + + await patch_and_serve_http(app, sock, args) + + # # NB: Await server shutdown only after the backend context is exited + # try: + # await shutdown_task + # finally: + # sock.close() + + +def dummy_add_signal_handler(self, *args, **kwargs): + # DO NOTHING HERE + pass + + +async def patch_and_serve_http(app, sock, args): + """Patch the add_signal_handler method and serve the app.""" + loop = asyncio.get_event_loop() + original_add_signal_handler = loop.add_signal_handler + loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) + OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator + + try: + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level="info", + access_log=True, + timeout_keep_alive=10, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + ) + await shutdown_task + finally: + loop.add_signal_handler = original_add_signal_handler + sock.close() + + +async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str): + parser = FlexibleArgumentParser(description="Run the OpenAI API server.") + args = make_arg_parser(parser) + args = parser.parse_args(["--host", str(host), "--port", str(port), "--model", model_path]) + print(args) + await run_server_in_ray(args, async_llm) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 1e3eb87058..2568c88005 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- """Base Model Class""" +import asyncio import socket import time from abc import ABC, abstractmethod -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union import openai import ray +import torch from torch import Tensor from trinity.common.experience import Experience @@ -16,35 +18,19 @@ class InferenceModel(ABC): """A model for high performance for rollout inference.""" - def generate(self, prompts: List[str], **kwargs) -> List[Experience]: - """Generate a batch of responses from a batch of prompts.""" - raise NotImplementedError - - def chat(self, messages: List[dict], **kwargs) -> List[Experience]: - """Generate experiences from a list of history chat messages.""" - raise NotImplementedError - - def logprobs(self, token_ids: List[int]) -> Tensor: - """Generate logprobs for a list of tokens.""" - raise NotImplementedError - - def convert_messages_to_experience(self, messages: List[dict]) -> Experience: - """Convert a list of messages into an experience.""" - raise NotImplementedError - - async def generate_async(self, prompt: str, **kwargs) -> List[Experience]: + async def generate(self, prompt: str, **kwargs) -> List[Experience]: """Generate a responses from a prompt in async.""" raise NotImplementedError - async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]: + async def chat(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate experiences from a list of history chat 
messages in async.""" raise NotImplementedError - async def logprobs_async(self, tokens: List[int]) -> Tensor: + async def logprobs(self, tokens: List[int]) -> Tensor: """Generate logprobs for a list of tokens in async.""" raise NotImplementedError - async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience: + async def convert_messages_to_experience(self, messages: List[dict]) -> Experience: """Convert a list of messages into an experience in async.""" raise NotImplementedError @@ -61,43 +47,84 @@ def get_available_address(self) -> Tuple[str, int]: return address, port +def _history_recorder(func): + """Decorator to record history of the model calls.""" + + async def async_wrapper(self, *args, **kwargs): + result = await func(self, *args, **kwargs) + if self.enable_history: + self._record_history(result) + return result + + def sync_wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + if self.enable_history: + self._record_history(result) + return result + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + class ModelWrapper: """A wrapper for the InferenceModel Ray Actor""" # TODO: check model_type inside __init__ - def __init__(self, model: Any, model_type: str = "vllm"): + def __init__(self, model: Any, model_type: str = "vllm", enable_history: bool = False): + assert model_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model - self.use_async = model_type == "vllm_async" self.openai_client: openai.OpenAI = None self.logger = get_logger(__name__) + self.enable_history = enable_history + self.history = [] + + def _record_history(self, exps: Union[Experience, List[Experience]]) -> None: + """Record experiences to history.""" + if isinstance(exps, Experience): + self.history.append(exps) + elif isinstance(exps, list): + self.history.extend(exps) + else: + raise TypeError("Expected Experience or List[Experience], got {}".format(type(exps))) + @_history_recorder def generate(self, prompts: List[str], **kwargs) -> List[Experience]: - if self.use_async: - results = ray.get( - [self.model.generate_async.remote(prompt, **kwargs) for prompt in prompts] - ) - return [exp for exps in results for exp in exps] - else: - return ray.get(self.model.generate.remote(prompts, **kwargs)) + """Generate a list of experiences from a list of prompts.""" + results = ray.get([self.model.generate.remote(prompt, **kwargs) for prompt in prompts]) + return [exp for exps in results for exp in exps] + + @_history_recorder + async def generate_async(self, prompts: List[str], **kwargs) -> List[Experience]: + """Generate a list of experiences from a list of prompts in async.""" + results = await asyncio.gather( + *[self.model.generate.remote(prompt, **kwargs) for prompt in prompts] + ) + return [exp for exps in results for exp in exps] + @_history_recorder def chat(self, messages: List[dict], **kwargs) -> List[Experience]: - if self.use_async: - return ray.get(self.model.chat_async.remote(messages, **kwargs)) - else: - return ray.get(self.model.chat.remote(messages, **kwargs)) + """Generate a list of experiences from a list of messages.""" + return ray.get(self.model.chat.remote(messages, **kwargs)) + + @_history_recorder + async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]: + """Generate a list of experiences from a list of messages in async.""" + return await self.model.chat.remote(messages, **kwargs) def logprobs(self, tokens: List[int]) -> Tensor: - if self.use_async: - return 
ray.get(self.model.logprobs_async.remote(tokens)) - else: - return ray.get(self.model.logprobs.remote(tokens)) + """Calculate the logprobs of the given tokens.""" + return ray.get(self.model.logprobs.remote(tokens)) + + async def logprobs_async(self, tokens: List[int]) -> Tensor: + """Calculate the logprobs of the given tokens in async.""" + return await self.model.logprobs.remote(tokens) def convert_messages_to_experience(self, messages: List[dict]) -> Experience: """Convert a list of messages into an experience.""" - if self.use_async: - return ray.get(self.model.convert_messages_to_experience_async.remote(messages)) - else: - return ray.get(self.model.convert_messages_to_experience.remote(messages)) + return ray.get(self.model.convert_messages_to_experience.remote(messages)) + + async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience: + """Convert a list of messages into an experience in async.""" + return await self.model.convert_messages_to_experience.remote(messages) @property def model_version(self) -> int: @@ -117,9 +144,9 @@ def get_openai_client(self) -> openai.OpenAI: "OpenAI API server is not running on current model." "Please set `enable_openai_api` to `True`." ) - api_address, model_path = None, None + api_address = None while True: - api_address, model_path = ray.get(self.model.api_server_ready.remote()) + api_address = ray.get(self.model.api_server_ready.remote()) if api_address is not None: break else: @@ -134,5 +161,64 @@ def get_openai_client(self) -> openai.OpenAI: base_url=api_address, api_key="EMPTY", ) - setattr(self.openai_client, "model_path", model_path) # TODO: may be removed + if self.enable_history: + # add a decorator to the openai client to record history + ori_create = self.openai_client.chat.completions.create + + def record_chat_completions(*args, **kwargs): + response = ori_create(*args, **kwargs) + self.history.extend(convert_api_output_to_experience(response)) + return response + + self.openai_client.chat.completions.create = record_chat_completions + setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id) return self.openai_client + + def extract_experience_from_history(self, clear_history: bool = True) -> List[Experience]: + """Extract experiences from the history.""" + if not self.enable_history: + raise ValueError("History recording is not enabled.") + exps = [exp for exp in self.history] + if clear_history: + self.history.clear() + return exps + + +def convert_api_output_to_experience( + output, +) -> List[Experience]: + """Convert the API output to a list of experiences.""" + return [ + Experience( + tokens=torch.cat( + ( + torch.tensor(output.prompt_token_ids, dtype=torch.int32), + torch.tensor(choice.token_ids, dtype=torch.int32), + ) + ), + logprobs=torch.cat( + ( + torch.full( + (len(output.prompt_token_ids),), + 0.0, + dtype=torch.float32, + ), + extract_logprobs(choice), + ) + ), + prompt_length=len(output.prompt_token_ids), + prompt_text=None, + response_text=choice.message.content, + ) + for choice in output.choices + ] + + +def extract_logprobs(choice) -> Tensor: + """Extract logprobs from a list of logprob dictionaries.""" + if not hasattr(choice, "logprobs") or choice.logprobs is None: + return torch.tensor([], dtype=torch.float32) + return torch.tensor( + [logprob.logprob for logprob in choice.logprobs.content], + dtype=torch.float32, + ) diff --git a/trinity/common/models/openai_api.py b/trinity/common/models/openai_api.py deleted file mode 100644 index 
c26b0ca54b..0000000000 --- a/trinity/common/models/openai_api.py +++ /dev/null @@ -1,79 +0,0 @@ -"""OpenAI API server related tools. - -Modified from vllm/entrypoints/openai/api_server.py -""" -import asyncio -import functools - -from vllm.entrypoints.launcher import serve_http -from vllm.entrypoints.openai.api_server import ( - build_app, - create_server_socket, - init_app_state, -) -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser, set_ulimit - - -async def run_server_in_ray(args, engine_client): - # workaround to make sure that we bind the port before the engine is set up. - # This avoids race conditions with ray. - # see https://github.com/vllm-project/vllm/issues/8204 - sock_addr = (args.host, args.port) - sock = create_server_socket(sock_addr) - - # workaround to avoid footguns where uvicorn drops requests with too - # many concurrent requests active - set_ulimit() - app = build_app(args) - - vllm_config = await engine_client.get_vllm_config() - await init_app_state(engine_client, vllm_config, app.state, args) - - await patch_and_serve_http(app, sock, args) - - # # NB: Await server shutdown only after the backend context is exited - # try: - # await shutdown_task - # finally: - # sock.close() - - -def dummy_add_signal_handler(self, *args, **kwargs): - # DO NOTHING HERE - pass - - -async def patch_and_serve_http(app, sock, args): - """Patch the add_signal_handler method and serve the app.""" - loop = asyncio.get_event_loop() - original_add_signal_handler = loop.add_signal_handler - loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) - - try: - shutdown_task = await serve_http( - app, - sock=sock, - enable_ssl_refresh=args.enable_ssl_refresh, - host=args.host, - port=args.port, - log_level="info", - access_log=True, - timeout_keep_alive=10, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - ) - await shutdown_task - finally: - loop.add_signal_handler = original_add_signal_handler - sock.close() - - -async def run_api_server_in_ray_actor(async_llm, host: str, port: int, model_path: str): - parser = FlexibleArgumentParser(description="Run the OpenAI API server.") - args = make_arg_parser(parser) - args = parser.parse_args(["--host", str(host), "--port", str(port), "--model", model_path]) - print(args) - await run_server_in_ray(args, async_llm) diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py deleted file mode 100644 index f3253bcf4b..0000000000 --- a/trinity/common/models/vllm_async_model.py +++ /dev/null @@ -1,364 +0,0 @@ -"""vLLM AsyncEngine wrapper. - -Modified from Ray python/ray/llm/_internal/batch/stages/vllm_engine_stage.py -""" - -import os -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -import aiohttp -import ray -import torch -import vllm -from vllm.sampling_params import RequestOutputKind - -from trinity.common.config import InferenceModelConfig -from trinity.common.experience import Experience -from trinity.common.models.model import InferenceModel -from trinity.common.models.utils import ( - tokenize_and_mask_messages_default, - tokenize_and_mask_messages_hf, -) -from trinity.utils.log import get_logger - -logger = get_logger(__name__) - - -# TODO: merge into vLLMRolloutModel -# TODO: remove V0 when V1 is stable -class vLLMAysncRolloutModel(InferenceModel): - """Wrapper around the vLLM engine to handle async requests. 
- - Args: - config (Config): The config. - kwargs (dict): The keyword arguments for the engine. - """ - - def __init__( - self, - config: InferenceModelConfig, - ) -> None: - self.logger = get_logger(__name__) - self.config = config - self.use_v1 = config.use_v1 - if config.tensor_parallel_size != 1: - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices - if not vllm.envs.is_set("VLLM_USE_V1"): - self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") - os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) - if config.use_v1: - os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - self.default_sampling_params = vllm.SamplingParams( - n=1, - temperature=0.0, - max_tokens=config.max_response_tokens, - min_tokens=1, - truncate_prompt_tokens=config.max_prompt_tokens, - skip_special_tokens=True, - include_stop_str_in_output=False, - output_kind=RequestOutputKind.FINAL_ONLY, - logprobs=0, - ) - self.enable_thinking = config.enable_thinking - self.request_id = 0 - max_model_len = None - if config.max_prompt_tokens is not None and config.max_response_tokens is not None: - max_model_len = config.max_prompt_tokens + config.max_response_tokens - engine_args = vllm.AsyncEngineArgs( - model=config.model_path, - enforce_eager=config.enforce_eager, - worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension", - tensor_parallel_size=config.tensor_parallel_size, - seed=config.seed, - distributed_executor_backend=("uni" if config.tensor_parallel_size == 1 else "ray"), - max_model_len=max_model_len, - enable_prefix_caching=config.enable_prefix_caching, - dtype=config.dtype, - trust_remote_code=True, - task="generate", - disable_log_requests=True, - gpu_memory_utilization=config.gpu_memory_utilization, - enable_chunked_prefill=config.enable_chunked_prefill, - # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage - ) - self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) - self.tokenizer = None - self.chat_template = None - if self.config.chat_template: - self.chat_template = self.config.chat_template - if self.chat_template is None or not re.search( - r"\{\%-?\s*generation\s*-?\%\}", self.chat_template - ): - self.logger.warning( - "The provided chat template does not support `return_assitant_tokens_mask`. " - "The default assistant mask method will be used, which may cause performance " - "degradation and lead to incorrect results." - ) - self.action_mask_method = tokenize_and_mask_messages_default - else: - self.action_mask_method = tokenize_and_mask_messages_hf - self.state_dict_meta = None - self.model_version = 0 # TODO: resume the value from the checkpoint - self.api_server_host = None - self.api_server_port = None - - async def chat_async(self, messages: List[Dict], **kwargs) -> List[Experience]: - """Chat with the model with a list of messages in async. - - Args: - messages (List[dict]): The input history messages. - kwargs (dict): A dictionary of sampling parameters. - - Returns: - A list of experiences. 
- """ - if self.tokenizer is None: - self.tokenizer = await self.async_llm.get_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - if messages[-1]["role"] == "assistant": - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - continue_final_message=True, - chat_template=self.chat_template, - ) - else: - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - chat_template=self.chat_template, - enable_thinking=self.enable_thinking, - ) - return await self.generate_async(prompt=prompt, **kwargs) - - async def generate_async(self, prompt: str, **kwargs) -> List[Experience]: - """Generate a response from the provided prompt in async. - - Args: - prompt (str): The input prompt. - kwargs (dict): A dictionary of sampling parameters. - - Returns: - A list of experiences. - """ - output = await self._generate_internal(prompt=prompt, **kwargs) - experiences = [ - Experience( - tokens=torch.cat( - ( - torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), - ) - ), - logprobs=torch.cat( - ( - torch.full( - (len(output.prompt_token_ids),), - 0.0, - dtype=torch.float32, - ), - torch.tensor( - [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.outputs[i].logprobs - ], - dtype=torch.float32, - ), - ) - ), - prompt_length=len(output.prompt_token_ids), - prompt_text=output.prompt, - response_text=output.outputs[i].text, - ) - for i in range(len(output.outputs)) - ] - return experiences - - async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor: - """Calculate the logprobs of the given tokens in async.""" - output = await self._generate_internal( - prompt={"prompt_token_ids": token_ids}, - n=1, - max_tokens=1, - prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token - ) - return torch.tensor( - [0] - + [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.prompt_logprobs[1:] - ], - dtype=torch.float32, - ) - - async def _generate_internal(self, prompt: Any, **kwargs) -> Any: - # Send the request to the LLM engine. - self.request_id += 1 - stream = self.async_llm.generate( - request_id=str(self.request_id), - prompt=prompt, - sampling_params=self._create_sampling_params(**kwargs), - ) - - # Consume the stream until the request is finished. - async for request_output in stream: - if request_output.finished: - # Bypass the original full prompt. - # request_output.prompt = request.prompt - return request_output - - raise RuntimeError("[vLLM] The request is not finished. This should not happen.") - - async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience: - """Convert a list of messages into an experience.""" - if self.tokenizer is None: - self.tokenizer = await self.async_llm.get_tokenizer() - if self.chat_template is None: - self.chat_template = self.tokenizer.get_chat_template() - token_ids, action_mask = self.action_mask_method( - self.tokenizer, messages, self.chat_template - ) - logprobs = await self.logprobs_async(token_ids=token_ids.tolist()) - return Experience( - tokens=token_ids, - prompt_length=len(token_ids), - logprobs=logprobs, - action_mask=action_mask, - ) - - def shutdown(self): - """Shutdown the vLLM v1 engine. This kills child processes forked - by the vLLM engine. 
If not called, the child processes will be - orphaned and will not be killed when the parent process exits, - and they won't be able to be tracked by Ray anymore. - """ - if hasattr(self.async_llm, "shutdown"): - logger.info("Shutting down vLLM engine") - self.async_llm.shutdown() - - def _create_sampling_params(self, **kwargs): - """Create sampling params.""" - if len(kwargs) == 0: - return self.default_sampling_params - params = self.default_sampling_params.clone() - for k, v in kwargs.items(): - if hasattr(params, k): - setattr(params, k, v) - return params - - async def _collective_rpc( - self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - ): - if self.use_v1: - return await self.async_llm.collective_rpc(method, timeout, args, kwargs) - else: - return self.async_llm.engine.model_executor.collective_rpc( - method, timeout, args, kwargs - ) - - async def sync_model( - self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None - ) -> bool: - """Sync model weights to vLLM.""" - if update_weight_args_list is not None: - await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) - await self._collective_rpc("update_weight") - self.logger.info("Sync model weights to vLLM successfully.") - self.model_version = model_version - return True - - async def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - explorer_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - state_dict_meta: dict = None, - ): - return await self._collective_rpc( - "init_process_group", - args=( - master_address, - master_port, - rank_offset, - world_size, - group_name, - backend, - timeout, - update_with_checkpoint, - state_dict_meta, - explorer_name, - ray.get_runtime_context().namespace, - ), - ) - - async def run_api_server(self): - """Run the OpenAI API server in a Ray actor. - - Note: - Do not use `ray.get()` on this method. - This method will run forever until the server is shut down. - """ - if not (self.api_server_host is None or self.api_server_port is None): - raise RuntimeError("API server is already running.") - from trinity.common.models.openai_api import run_api_server_in_ray_actor - - self.api_server_host, self.api_server_port = self.get_available_address() - await run_api_server_in_ray_actor( - self.async_llm, self.api_server_host, self.api_server_port, self.config.model_path - ) - - async def has_api_server(self) -> bool: - return self.config.enable_openai_api - - async def api_server_ready(self) -> Tuple[Union[str, None], Union[str, None]]: - """Check if the OpenAI API server is ready. - - Returns: - api_url (str): The URL of the OpenAI API server. - model_path (str): The path of the model. 
- """ - if not await self.has_api_server(): - return None, None - try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://{self.api_server_host}:{self.api_server_port}/health" - ) as response: - if response.status == 200: - return ( - f"http://{self.api_server_host}:{self.api_server_port}/v1", - self.config.model_path, - ) - else: - return None, None - except Exception as e: - self.logger.error(e) - return None, None - - async def reset_prefix_cache(self) -> None: - await self.async_llm.reset_prefix_cache() - - def get_model_version(self) -> int: - return self.model_version - - async def sleep(self, level: int = 1) -> None: - await self.async_llm.sleep(level=level) - - async def wake_up(self) -> None: - await self.async_llm.wake_up() diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 643a124f72..01b8135511 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -1,20 +1,15 @@ -# -*- coding: utf-8 -*- -"""vLLM related modules. - -Modified from OpenRLHF openrlhf/trainer/ray/vllm_engine.py +"""A wrapper around the vllm.AsyncEngine to handle async requests. """ -from __future__ import annotations import os import re -import threading -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union +import aiohttp import ray import torch import vllm -from vllm import LLM -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience @@ -25,13 +20,25 @@ ) from trinity.utils.log import get_logger +logger = get_logger(__name__) + +# TODO: remove V0 when V1 is stable class vLLMRolloutModel(InferenceModel): - """Actor for vLLM.""" + """Wrapper around the vLLM engine to handle async requests. + + Args: + config (Config): The config. + kwargs (dict): The keyword arguments for the engine. 
+ """ - def __init__(self, config: InferenceModelConfig): + def __init__( + self, + config: InferenceModelConfig, + ) -> None: self.logger = get_logger(__name__) self.config = config + self.use_v1 = config.use_v1 if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices @@ -42,7 +49,7 @@ def __init__(self, config: InferenceModelConfig): os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - self.default_sampling_params = SamplingParams( + self.default_sampling_params = vllm.SamplingParams( n=1, temperature=0.0, max_tokens=config.max_response_tokens, @@ -50,13 +57,15 @@ def __init__(self, config: InferenceModelConfig): truncate_prompt_tokens=config.max_prompt_tokens, skip_special_tokens=True, include_stop_str_in_output=False, + output_kind=RequestOutputKind.FINAL_ONLY, logprobs=0, ) + self.enable_thinking = config.enable_thinking + self.request_id = 0 max_model_len = None if config.max_prompt_tokens is not None and config.max_response_tokens is not None: max_model_len = config.max_prompt_tokens + config.max_response_tokens - self.llm = LLM( - # TODO: check checkpoint path + engine_args = vllm.AsyncEngineArgs( model=config.model_path, enforce_eager=config.enforce_eager, worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension", @@ -67,16 +76,20 @@ def __init__(self, config: InferenceModelConfig): enable_prefix_caching=config.enable_prefix_caching, dtype=config.dtype, trust_remote_code=True, + task="generate", + disable_log_requests=True, gpu_memory_utilization=config.gpu_memory_utilization, enable_chunked_prefill=config.enable_chunked_prefill, - # max_num_batched_tokens=256, + # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage ) - self.tokenizer = self.llm.get_tokenizer() - self.chat_template = self.tokenizer.get_chat_template() - self.enable_thinking = config.enable_thinking + self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) + self.tokenizer = None + self.chat_template = None if self.config.chat_template: self.chat_template = self.config.chat_template - if not re.search(r"\{\%-?\s*generation\s*-?\%\}", self.chat_template): + if self.chat_template is None or not re.search( + r"\{\%-?\s*generation\s*-?\%\}", self.chat_template + ): self.logger.warning( "The provided chat template does not support `return_assitant_tokens_mask`. 
" "The default assistant mask method will be used, which may cause performance " @@ -85,146 +98,25 @@ def __init__(self, config: InferenceModelConfig): self.action_mask_method = tokenize_and_mask_messages_default else: self.action_mask_method = tokenize_and_mask_messages_hf - self.lock = threading.Lock() self.state_dict_meta = None self.model_version = 0 # TODO: resume the value from the checkpoint + self.api_server_host = None + self.api_server_port = None - def init_process_group( - self, - master_address: str, - master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - explorer_name: str, - backend: str = "nccl", - timeout: int = 1200, - update_with_checkpoint: bool = True, - state_dict_meta: dict = None, - ): - return self.llm.collective_rpc( - "init_process_group", - args=( - master_address, - master_port, - rank_offset, - world_size, - group_name, - backend, - timeout, - update_with_checkpoint, - state_dict_meta, - explorer_name, - ray.get_runtime_context().namespace, - ), - ) - - def reset_prefix_cache(self): - self.llm.llm_engine.reset_prefix_cache() - - def sleep(self, level=1): - self.llm.sleep(level=level) - - def wake_up(self): - self.llm.wake_up() - - def _create_sampling_params(self, **kwargs): - """Create sampling params.""" - if len(kwargs) == 0: - return self.default_sampling_params - params = self.default_sampling_params.clone() - for k, v in kwargs.items(): - if hasattr(params, k): - setattr(params, k, v) - return params - - def generate(self, prompts: List[str], **kwargs) -> List: - """ - Generate a batch of responses from a batch of prompts. - - Note: - - This method will not apply chat template. - You need to apply chat template before calling this method. + async def chat(self, messages: List[Dict], **kwargs) -> List[Experience]: + """Chat with the model with a list of messages in async. Args: - prompts (List[str]): A list of prompts. + messages (List[dict]): The input history messages. kwargs (dict): A dictionary of sampling parameters. Returns: - List[Experience]: A list of experiences. - - Example: - - >>> # config.algorithm.repeat_times == 2 or kwargs["n"] == 2 - >>> - >>> prompts = [ - >>> "Hello, world!", - >>> "How are you?" - >>> ] - >>> experiences = model.generate(prompts) - >>> print(experiences) - [ - Experience(tokens=tensor()...), # first sequnece for prompts[0] - Experience(tokens=tensor()...), # second sequnece for prompts[0] - Experience(tokens=tensor()...), # first sequence for prompts[1] - Experience(tokens=tensor()...) # second sequence for prompts[1] - ] + A list of experiences. """ - with self.lock: - sampling_params = self._create_sampling_params(**kwargs) - outputs = self.llm.generate(prompts, sampling_params, use_tqdm=False) - experiences = [] - for output in outputs: - for i in range(sampling_params.n): - experiences.append( - Experience( - tokens=torch.cat( - ( - torch.tensor(output.prompt_token_ids, dtype=torch.int32), - torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), - ) - ), - logprobs=torch.cat( - ( - torch.full( - (len(output.prompt_token_ids),), - 0.0, - dtype=torch.float32, - ), - torch.tensor( - [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.outputs[i].logprobs - ], - dtype=torch.float32, - ), - ) - ), - prompt_length=len(output.prompt_token_ids), - prompt_text=output.prompt, - response_text=output.outputs[i].text, - ) - ) - return experiences - - def chat(self, messages: List[dict], **kwargs) -> List[Experience]: - """Chat with the model with a list of messages. 
- - Args: - messages (List[dict]): A list of messages. - - Example: - - >>> [ - >>> {"role": "system", "content": "You are a helpful assistant."}, - >>> {"role": "user", "content": "Hello, world!"}, - >>> ] - - Returns: - List[Experience]: A list of experiences containing the response text. - """ - # TODO: support tools and other fields + if self.tokenizer is None: + self.tokenizer = await self.async_llm.get_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() if messages[-1]["role"] == "assistant": prompt = self.tokenizer.apply_chat_template( messages, @@ -240,34 +132,96 @@ def chat(self, messages: List[dict], **kwargs) -> List[Experience]: chat_template=self.chat_template, enable_thinking=self.enable_thinking, ) - return self.generate([prompt], **kwargs) - - def logprobs(self, token_ids: List[int]) -> torch.Tensor: - with self.lock: - outputs = self.llm.generate( - prompts={"prompt_token_ids": token_ids}, - sampling_params=self._create_sampling_params( - n=1, - max_tokens=1, - prompt_logprobs=0, + return await self.generate(prompt=prompt, **kwargs) + + async def generate(self, prompt: str, **kwargs) -> List[Experience]: + """Generate a response from the provided prompt in async. + + Args: + prompt (str): The input prompt. + kwargs (dict): A dictionary of sampling parameters. + + Returns: + A list of experiences. + """ + output = await self._generate_internal(prompt=prompt, **kwargs) + experiences = [ + Experience( + tokens=torch.cat( + ( + torch.tensor(output.prompt_token_ids, dtype=torch.int32), + torch.tensor(output.outputs[i].token_ids, dtype=torch.int32), + ) ), - use_tqdm=False, + logprobs=torch.cat( + ( + torch.full( + (len(output.prompt_token_ids),), + 0.0, + dtype=torch.float32, + ), + torch.tensor( + [ + list(logprob_dict.values())[0].logprob + for logprob_dict in output.outputs[i].logprobs + ], + dtype=torch.float32, + ), + ) + ), + prompt_length=len(output.prompt_token_ids), + prompt_text=output.prompt, + response_text=output.outputs[i].text, ) + for i in range(len(output.outputs)) + ] + return experiences + + async def logprobs(self, token_ids: List[int]) -> torch.Tensor: + """Calculate the logprobs of the given tokens in async.""" + output = await self._generate_internal( + prompt={"prompt_token_ids": token_ids}, + n=1, + max_tokens=1, + prompt_logprobs=0, # vLLM returns `prompt_logprobs + 1` logprobs for each token + ) return torch.tensor( [0] + [ list(logprob_dict.values())[0].logprob - for logprob_dict in outputs[0].prompt_logprobs[1:] + for logprob_dict in output.prompt_logprobs[1:] ], dtype=torch.float32, ) - def convert_messages_to_experience(self, messages: List[dict]) -> Experience: + async def _generate_internal(self, prompt: Any, **kwargs) -> Any: + # Send the request to the LLM engine. + self.request_id += 1 + stream = self.async_llm.generate( + request_id=str(self.request_id), + prompt=prompt, + sampling_params=self._create_sampling_params(**kwargs), + ) + + # Consume the stream until the request is finished. + async for request_output in stream: + if request_output.finished: + # Bypass the original full prompt. + # request_output.prompt = request.prompt + return request_output + + raise RuntimeError("[vLLM] The request is not finished. 
This should not happen.") + + async def convert_messages_to_experience(self, messages: List[dict]) -> Experience: """Convert a list of messages into an experience.""" + if self.tokenizer is None: + self.tokenizer = await self.async_llm.get_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() token_ids, action_mask = self.action_mask_method( self.tokenizer, messages, self.chat_template ) - logprobs = self.logprobs(token_ids=token_ids.tolist()) + logprobs = await self.logprobs(token_ids=token_ids.tolist()) return Experience( tokens=token_ids, prompt_length=len(token_ids), @@ -275,19 +229,129 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience: action_mask=action_mask, ) - def has_api_server(self) -> bool: - return False + def shutdown(self): + """Shutdown the vLLM v1 engine. This kills child processes forked + by the vLLM engine. If not called, the child processes will be + orphaned and will not be killed when the parent process exits, + and they won't be able to be tracked by Ray anymore. + """ + if hasattr(self.async_llm, "shutdown"): + logger.info("Shutting down vLLM engine") + self.async_llm.shutdown() - def sync_model( + def _create_sampling_params(self, **kwargs): + """Create sampling params.""" + if len(kwargs) == 0: + return self.default_sampling_params + params = self.default_sampling_params.clone() + for k, v in kwargs.items(): + if hasattr(params, k): + setattr(params, k, v) + return params + + async def _collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): + if self.use_v1: + return await self.async_llm.collective_rpc(method, timeout, args, kwargs) + else: + return self.async_llm.engine.model_executor.collective_rpc( + method, timeout, args, kwargs + ) + + async def sync_model( self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None ) -> bool: """Sync model weights to vLLM.""" if update_weight_args_list is not None: - self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) - self._collective_rpc("update_weight") + await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) + await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version return True + async def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + explorer_name: str, + backend: str = "nccl", + timeout: int = 1200, + update_with_checkpoint: bool = True, + state_dict_meta: dict = None, + ): + return await self._collective_rpc( + "init_process_group", + args=( + master_address, + master_port, + rank_offset, + world_size, + group_name, + backend, + timeout, + update_with_checkpoint, + state_dict_meta, + explorer_name, + ray.get_runtime_context().namespace, + ), + ) + + async def run_api_server(self): + """Run the OpenAI API server in a Ray actor. + + Note: + Do not use `ray.get()` on this method. + This method will run forever until the server is shut down. 
+ """ + if not (self.api_server_host is None or self.api_server_port is None): + raise RuntimeError("API server is already running.") + from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor + + self.api_server_host, self.api_server_port = self.get_available_address() + await run_api_server_in_ray_actor( + self.async_llm, self.api_server_host, self.api_server_port, self.config.model_path + ) + + async def has_api_server(self) -> bool: + return self.config.enable_openai_api + + async def api_server_ready(self) -> Union[str, None]: + """Check if the OpenAI API server is ready. + + Returns: + api_url (str): The URL of the OpenAI API server. + """ + if not await self.has_api_server(): + return None + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://{self.api_server_host}:{self.api_server_port}/health" + ) as response: + if response.status == 200: + return f"http://{self.api_server_host}:{self.api_server_port}/v1" + else: + return None + except Exception as e: + self.logger.error(e) + return None + + async def reset_prefix_cache(self) -> None: + await self.async_llm.reset_prefix_cache() + def get_model_version(self) -> int: return self.model_version + + async def sleep(self, level: int = 1) -> None: + await self.async_llm.sleep(level=level) + + async def wake_up(self) -> None: + await self.async_llm.wake_up()