diff --git a/.github/workflows/sphinx-doc.yaml b/.github/workflows/sphinx-doc.yaml
index c92cfb6d65..fc7963bbbf 100644
--- a/.github/workflows/sphinx-doc.yaml
+++ b/.github/workflows/sphinx-doc.yaml
@@ -11,7 +11,7 @@ on:
 
 jobs:
   pages:
-    timeout-minutes: 20
+    timeout-minutes: 30
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
diff --git a/pyproject.toml b/pyproject.toml
index 56aec06e96..2134840c4b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13"
 dependencies = [
     "verl==0.5.0",
     "ray[default]>=2.48.0",
-    "vllm>=0.9.1,<=0.11.0",
+    "vllm>=0.10.2,<=0.11.0",
     "tensordict",
     "wandb",
     "omegaconf",
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index c7fcd75a87..f60d39ccf9 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1,5 +1,6 @@
 import unittest
 
+import ray
 import torch
 from openai import BadRequestError
 from parameterized import parameterized_class
@@ -11,6 +12,7 @@
     get_model_path,
     get_template_config,
 )
+from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
@@ -310,8 +312,9 @@ async def test_api(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(response.choices[0].logprobs is not None)
-        self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
-        self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
+        self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
+        # check the 3rd token's logprob, because the first two tokens (`` and `\n`) usually have zero logprob
+        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -361,6 +364,89 @@ async def test_api(self):
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
 
 
+class DummySynchronizer:
+    def __init__(self):
+        pass
+
+    def do_nothing(self):
+        pass
+
+
+class TestLogprobs(RayUnittestBaseAysnc):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.explorer.rollout_model.enable_openai_api = True
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_logprobs(self):
+        # use init_process_group to apply the patches
+        sync = (
+            ray.remote(DummySynchronizer)
+            .options(name="synchronizer", namespace=self.config.ray_namespace)
+            .remote()
+        )
+        await sync.__ray_ready__.remote()
+        await self.model_wrapper.prepare()
+        master_address, master_port = await self.engines[0].get_available_address.remote()
+        await self.engines[0].init_process_group.remote(
+            master_address,
+            master_port,
+            world_size=1,
+            rank_offset=0,
+            group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+            explorer_name=self.config.explorer.name,
+            timeout=20,
+        )
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is your name?"},
+        ]
+        response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
+        response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
+        self.assertTrue(response_1.logprobs is not None)
+        self.assertTrue(len(response_1.logprobs) > 0)
+        self.assertTrue(response_2.logprobs is not None)
+        self.assertTrue(len(response_2.logprobs) > 0)
+        logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0)
+        logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8)
+        logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0)
+        logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
+        self.assertEqual(logprobs_1.shape, logprobs_2.shape)
+        self.assertEqual(logprobs_3.shape, logprobs_4.shape)
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
+        logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
+        logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
+        logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
+        self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
+        logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
+        logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
+        logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
+        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
+        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
+        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
+        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+
+
 class TestAsyncAPIServer(RayUnittestBaseAysnc):
     def setUp(self):
         self.config = get_template_config()
@@ -403,8 +489,9 @@ async def test_api_async(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(response.choices[0].logprobs is not None)
-        self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
-        self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
+        self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
+        # check the 3rd token's logprob, because the first two tokens (`` and `\n`) usually have zero logprob
+        self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
         self.assertTrue(hasattr(response, "prompt_token_ids"))
         self.assertTrue(len(response.prompt_token_ids) > 0)
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index f09756511d..a2ecdd90a4 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -31,11 +31,16 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
         """Generate experiences from a list of history chat messages in async."""
         raise NotImplementedError
 
-    async def logprobs(self, tokens: List[int]) -> Tensor:
+    async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
         """Generate logprobs for a list of tokens in async."""
         raise NotImplementedError
 
-    async def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
+    async def convert_messages_to_experience(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
         """Convert a list of messages into an experience in async."""
         raise NotImplementedError
 
@@ -205,21 +210,39 @@ async def chat_mm_async(
     ) -> List[Experience]:
         return await self.model.chat_mm.remote(messages, images=images, videos=videos, **kwargs)
 
-    def logprobs(self, tokens: List[int]) -> Tensor:
+    def logprobs(self, tokens: List[int], temperature: Optional[float] = None) -> Tensor:
         """Calculate the logprobs of the given tokens."""
-        return ray.get(self.model.logprobs.remote(tokens))
+        return ray.get(self.model.logprobs.remote(tokens, temperature=temperature))
 
-    async def logprobs_async(self, tokens: List[int]) -> Tensor:
+    async def logprobs_async(
+        self, tokens: List[int], temperature: Optional[float] = None
+    ) -> Tensor:
         """Calculate the logprobs of the given tokens in async."""
-        return await self.model.logprobs.remote(tokens)
+        return await self.model.logprobs.remote(tokens, temperature=temperature)
 
-    def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
+    def convert_messages_to_experience(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
         """Convert a list of messages into an experience."""
-        return ray.get(self.model.convert_messages_to_experience.remote(messages))
+        return ray.get(
+            self.model.convert_messages_to_experience.remote(
+                messages, tools=tools, temperature=temperature
+            )
+        )
 
-    async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
+    async def convert_messages_to_experience_async(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
         """Convert a list of messages into an experience in async."""
-        return await self.model.convert_messages_to_experience.remote(messages)
+        return await self.model.convert_messages_to_experience.remote(
+            messages, tools=tools, temperature=temperature
+        )
 
     @property
     def model_version(self) -> int:
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index eeb90cec12..d2d3b25c68 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -16,13 +16,13 @@
 from trinity.common.config import InferenceModelConfig
 from trinity.common.experience import Experience
-from trinity.common.models.api.vllm_patch import get_vllm_version
 from trinity.common.models.mm_utils import (
     build_multi_modal_inputs,
     convert_messages_to_mm_format,
 )
 from trinity.common.models.model import InferenceModel
 from trinity.common.models.utils import get_action_mask_method
+from trinity.common.models.vllm_patch.api_patch import get_vllm_version
 from trinity.utils.log import get_logger
@@ -100,6 +100,7 @@ def __init__(
             },
             disable_log_stats=True,
             enable_lora=config.enable_lora,
+            logprobs_mode="processed_logprobs",
             **config.lora_kwargs,
         )
         if get_vllm_version() > parse_version("0.10.0"):
@@ -307,8 +308,11 @@ async def generate_mm(
         ]
         return experiences
 
-    async def logprobs(
-        self, token_ids: List[int], lora_request: LoRARequest = None
+    async def logprobs(  # type: ignore [override]
+        self,
+        token_ids: List[int],
+        lora_request: LoRARequest = None,
+        temperature: Optional[float] = None,
     ) -> torch.Tensor:
         """Calculate the logprobs of the given tokens in async.
         Please slice the result carefully to align with the actual response length.
@@ -316,16 +320,22 @@ async def logprobs(
         Args:
             token_ids (List[int]): The input token ids (seq_length). Please make sure the
                 length of it does not exceed `max_model_len - 1`.
+            lora_request (LoRARequest, optional): The LoRA request. Defaults to None.
+            temperature (float): The temperature for scaling logits.
 
         Returns:
             A tensor of logprobs (seq_length - 1).
         """
+        temperature = temperature if temperature is not None else self.config.temperature
+        if temperature is None:
+            temperature = 1.0
         output = await self._generate_internal(
             prompt={"prompt_token_ids": token_ids},
             lora_request=lora_request,
             n=1,
             max_tokens=1,
             prompt_logprobs=0,  # vLLM return `prompt_logprobs + 1` logrpobs for each token
+            temperature=temperature,
         )
         return torch.tensor(
             [list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]],
@@ -357,6 +367,7 @@ async def convert_messages_to_experience(
         self,
         messages: List[dict],
         tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
     ) -> Experience:
         """Convert a list of messages into an experience."""
         if self.tokenizer is None:
@@ -370,7 +381,10 @@ async def convert_messages_to_experience(
             chat_template=self.chat_template,
             enable_thinking=self.enable_thinking,
         )  # (seq_length, ), (seq_length, )
-        logprobs = await self.logprobs(token_ids=token_ids.tolist())  # (seq_length - 1,)
+        temperature = temperature if temperature is not None else self.config.temperature
+        logprobs = await self.logprobs(
+            token_ids=token_ids.tolist(), temperature=temperature
+        )  # (seq_length - 1,)
         return Experience(
             tokens=token_ids,
             logprobs=logprobs[prompt_length - 1 :],
@@ -481,7 +495,9 @@ async def run_api_server(self) -> bool:
         if self.api_server_host is not None and self.api_server_port is not None:
             return True  # already running
 
-        from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor
+        from trinity.common.models.vllm_patch.api_patch import (
+            run_api_server_in_ray_actor,
+        )
 
         api_server_host, api_server_port = self.get_available_address()
         self.api_server = asyncio.create_task(
diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py
new file mode 100644
index 0000000000..294bb68eb4
--- /dev/null
+++ b/trinity/common/models/vllm_patch/__init__.py
@@ -0,0 +1,13 @@
+import vllm
+from packaging.version import InvalidVersion
+from packaging.version import parse as parse_version
+
+
+def get_vllm_version():
+    try:
+        vllm_version = parse_version(vllm.__version__)
+    except InvalidVersion:
+        # for self-compiled vllm,
+        # we cannot parse the version, treat it as the lowest version we support
+        vllm_version = parse_version("0.8.5")
+    return vllm_version
diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/vllm_patch/api_patch.py
similarity index 97%
rename from trinity/common/models/api/vllm_patch.py
rename to trinity/common/models/vllm_patch/api_patch.py
index 035500591c..71b61f1c5f 100644
--- a/trinity/common/models/api/vllm_patch.py
+++ b/trinity/common/models/vllm_patch/api_patch.py
@@ -10,7 +10,6 @@
 from typing import Optional, Union
 
 import vllm
-from packaging.version import InvalidVersion
 from packaging.version import parse as parse_version
 from pydantic import Field, TypeAdapter
 from vllm.entrypoints.launcher import serve_http
@@ -39,6 +38,7 @@
 from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import FlexibleArgumentParser, set_ulimit
 
+from trinity.common.models.vllm_patch import get_vllm_version
 from trinity.utils.log import get_logger
@@ -327,16 +327,6 @@ async def patch_and_serve_http(app, sock, args):
         sock.close()
 
 
-def get_vllm_version():
-    try:
-        vllm_version = parse_version(vllm.__version__)
-    except InvalidVersion:
-        # for self-compiled vllm,
-        # we cannot parse the version, trait it as the lowest version we support
-        vllm_version = parse_version("0.8.5")
-    return vllm_version
-
-
 async def run_api_server_in_ray_actor(
     async_llm,
     host: str,
diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py
new file mode 100644
index 0000000000..ebe9d47ac3
--- /dev/null
+++ b/trinity/common/models/vllm_patch/worker_patch.py
@@ -0,0 +1,136 @@
+from types import MethodType
+from typing import Optional
+
+import torch
+import vllm
+from packaging.version import parse as parse_version
+from vllm.v1.outputs import LogprobsTensors
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+from trinity.common.models.vllm_patch import get_vllm_version
+
+
+def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner):
+    """Patch vLLM model runner to support prompt logprobs extraction."""
+    if get_vllm_version() < parse_version("0.10.2"):
+        raise ValueError(
+            f"Unsupported vllm version: {vllm.__version__}. "
+            "This patch requires vllm version >= 0.10.2, <= 0.11.0."
+        )
+
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        num_scheduled_tokens: dict[str, int],
+    ) -> dict[str, Optional[LogprobsTensors]]:
+        """Patched version of _get_prompt_logprobs_dict.
+
+        This is a monkey-patched version of `_get_prompt_logprobs_dict` from
+        `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions
+        0.10.2 to 0.11.0).
+
+        The original function does not apply temperature scaling to logits when
+        calculating prompt logprobs, which can lead to incorrect logprob values
+        when the temperature is not 1.0. This patch adds the missing
+        temperature scaling.
+        """
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+            num_tokens = num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            if request.prompt_token_ids is None:
+                # Prompt logprobs is incompatible with prompt embeddings
+                continue
+
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True
+            )
+
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1
+                )
+                in_progress_dict[req_id] = logprobs_tensors
+
+            # Determine number of logits to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then there is prompt logprob generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc.np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset : offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states)
+
+            # PATCH START
+            temp = request.sampling_params.temperature
+            if temp >= 1e-5:
+                logits.div_(temp)
+            # PATCH END
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids
+            )
+
+            # Transfer GPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True)
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking GPU->CPU transfers.
+        if prompt_logprobs_dict:
+            self._sync_device()
+
+        return prompt_logprobs_dict
+
+    model_runner._get_prompt_logprobs_dict = MethodType(_get_prompt_logprobs_dict, model_runner)
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 93d9c0bd48..1cd9a95a2c 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -5,6 +5,7 @@
 import torch.distributed
 from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
 
+from trinity.common.models.vllm_patch.worker_patch import patch_vllm_prompt_logprobs
 from trinity.manager.synchronizer import Synchronizer
 from trinity.utils.distributed import init_process_group
 from trinity.utils.log import get_logger
@@ -56,6 +57,7 @@ def init_process_group(
         self.synchronizer = Synchronizer.get_actor(namespace=self._namespace)
         self._checkpoint_converter = None
         patch_vllm_moe_model_weight_loader(self.model_runner.model)
+        patch_vllm_prompt_logprobs(self.model_runner)
 
     def update_weight(self):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
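
For context on the worker patch above: with logprobs_mode="processed_logprobs", vLLM already reports sampled-token logprobs from temperature-scaled logits, and the patch applies the same scaling on the prompt side. A minimal standalone sketch of the effect with toy numbers (not part of the patch):

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # toy logits for one prompt position
target = 1                                    # the "next" prompt token to score

for temp in (1.0, 0.8):
    # mirror the patch: divide logits by temperature, but leave greedy (temp ~ 0) unscaled
    scaled = logits / temp if temp >= 1e-5 else logits
    print(temp, torch.log_softmax(scaled, dim=-1)[target].item())

Without the scaling, the prompt logprobs always correspond to temperature 1.0, so they cannot match rollouts generated at other temperatures.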
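A rough usage sketch of the new temperature argument (assumes a model_wrapper and messages set up as in TestLogprobs above):

# Roll out at temperature 0.8, then recompute logprobs at the same temperature;
# the response slice should track the logprobs reported at rollout time.
response = model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
logprobs = model_wrapper.logprobs(response.tokens.tolist(), temperature=0.8)
response_logprobs = logprobs[response.prompt_length - 1 :]

# convert_messages_to_experience forwards the temperature to the engine as well.
exp = model_wrapper.convert_messages_to_experience(messages, temperature=0.8)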