From cd0173b4045765b28c583329305c616f048226ca Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 13 Nov 2025 18:33:41 +0800 Subject: [PATCH 01/10] patch vllm --- trinity/common/models/vllm_model.py | 6 +- trinity/common/models/vllm_patch/__init__.py | 13 ++ .../vllm_patch.py => vllm_patch/api_patch.py} | 12 +- .../common/models/vllm_patch/worker_patch.py | 124 ++++++++++++++++++ trinity/common/models/vllm_worker.py | 2 + 5 files changed, 144 insertions(+), 13 deletions(-) create mode 100644 trinity/common/models/vllm_patch/__init__.py rename trinity/common/models/{api/vllm_patch.py => vllm_patch/api_patch.py} (97%) create mode 100644 trinity/common/models/vllm_patch/worker_patch.py diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index eeb90cec12..3bfe70bb93 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -16,13 +16,13 @@ from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience -from trinity.common.models.api.vllm_patch import get_vllm_version from trinity.common.models.mm_utils import ( build_multi_modal_inputs, convert_messages_to_mm_format, ) from trinity.common.models.model import InferenceModel from trinity.common.models.utils import get_action_mask_method +from trinity.common.models.vllm_patch.api_patch import get_vllm_version from trinity.utils.log import get_logger @@ -481,7 +481,9 @@ async def run_api_server(self) -> bool: if self.api_server_host is not None and self.api_server_port is not None: return True # already running - from trinity.common.models.api.vllm_patch import run_api_server_in_ray_actor + from trinity.common.models.vllm_patch.api_patch import ( + run_api_server_in_ray_actor, + ) api_server_host, api_server_port = self.get_available_address() self.api_server = asyncio.create_task( diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py new file mode 100644 index 0000000000..294bb68eb4 --- /dev/null +++ b/trinity/common/models/vllm_patch/__init__.py @@ -0,0 +1,13 @@ +import vllm +from packaging.version import InvalidVersion +from packaging.version import parse as parse_version + + +def get_vllm_version(): + try: + vllm_version = parse_version(vllm.__version__) + except InvalidVersion: + # for self-compiled vllm, + # we cannot parse the version, trait it as the lowest version we support + vllm_version = parse_version("0.8.5") + return vllm_version diff --git a/trinity/common/models/api/vllm_patch.py b/trinity/common/models/vllm_patch/api_patch.py similarity index 97% rename from trinity/common/models/api/vllm_patch.py rename to trinity/common/models/vllm_patch/api_patch.py index 035500591c..71b61f1c5f 100644 --- a/trinity/common/models/api/vllm_patch.py +++ b/trinity/common/models/vllm_patch/api_patch.py @@ -10,7 +10,6 @@ from typing import Optional, Union import vllm -from packaging.version import InvalidVersion from packaging.version import parse as parse_version from pydantic import Field, TypeAdapter from vllm.entrypoints.launcher import serve_http @@ -39,6 +38,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.utils import FlexibleArgumentParser, set_ulimit +from trinity.common.models.vllm_patch import get_vllm_version from trinity.utils.log import get_logger @@ -327,16 +327,6 @@ async def patch_and_serve_http(app, sock, args): sock.close() -def get_vllm_version(): - try: - vllm_version = parse_version(vllm.__version__) - except InvalidVersion: - # for self-compiled vllm, - # we cannot parse the version, trait it as the lowest version we support - vllm_version = parse_version("0.8.5") - return vllm_version - - async def run_api_server_in_ray_actor( async_llm, host: str, diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py new file mode 100644 index 0000000000..80b4e4196e --- /dev/null +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -0,0 +1,124 @@ +from typing import Optional + +import torch +import vllm +from packaging.version import parse as parse_version +from vllm.v1.outputs import LogprobsTensors +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +from trinity.common.models.vllm_patch import get_vllm_version + + +def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): + """Patch vLLM model runner to support prompt logprobs extraction.""" + if get_vllm_version() < parse_version("0.10.0"): + raise ValueError( + f"Unsupported vllm version: {vllm.__version__}. " + "This patch requires vllm version >= 0.10.0, <= 0.11.0." + ) + + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, Optional[LogprobsTensors]]: + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + num_tokens = num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True + ) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc.np[req_idx].item() + prompt_hidden_states = hidden_states[offset : offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) + + # PATCH START + temp = request.sampling_params.temperature + if temp is None or temp >= 1e-5: + logits.div_(temp) + # PATCH END + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids + ) + + # Transfer GPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + if prompt_logprobs_dict: + self._sync_device() + + return prompt_logprobs_dict + + model_runner._get_prompt_logprobs_dict = _get_prompt_logprobs_dict diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 93d9c0bd48..1cd9a95a2c 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -5,6 +5,7 @@ import torch.distributed from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader +from trinity.common.models.vllm_patch.worker_patch import patch_vllm_prompt_logprobs from trinity.manager.synchronizer import Synchronizer from trinity.utils.distributed import init_process_group from trinity.utils.log import get_logger @@ -56,6 +57,7 @@ def init_process_group( self.synchronizer = Synchronizer.get_actor(namespace=self._namespace) self._checkpoint_converter = None patch_vllm_moe_model_weight_loader(self.model_runner.model) + patch_vllm_prompt_logprobs(self.model_runner) def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" From b4f0977e207e5d524edfe03121057b7ca684d173 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 13 Nov 2025 19:08:30 +0800 Subject: [PATCH 02/10] update patch --- pyproject.toml | 2 +- trinity/common/models/vllm_patch/worker_patch.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 56aec06e96..cfeedad03b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13" dependencies = [ "verl==0.5.0", "ray[default]>=2.48.0", - "vllm>=0.9.1,<=0.11.0", + "vllm>=0.10.0,<=0.11.0", "tensordict", "wandb", "omegaconf", diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index 80b4e4196e..9342538b02 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -1,3 +1,4 @@ +from types import MethodType from typing import Optional import torch @@ -121,4 +122,4 @@ def _get_prompt_logprobs_dict( return prompt_logprobs_dict - model_runner._get_prompt_logprobs_dict = _get_prompt_logprobs_dict + model_runner._get_prompt_logprobs_dict = MethodType(_get_prompt_logprobs_dict, model_runner) From 7dc051aa25117928460a146bd0b0296184a25c18 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 13 Nov 2025 19:15:48 +0800 Subject: [PATCH 03/10] update patch --- trinity/common/models/vllm_patch/worker_patch.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index 9342538b02..1a5a358593 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -23,6 +23,16 @@ def _get_prompt_logprobs_dict( hidden_states: torch.Tensor, num_scheduled_tokens: dict[str, int], ) -> dict[str, Optional[LogprobsTensors]]: + """Patched version of _get_prompt_logprobs_dict. + + This is a monkey-patched version of `_get_prompt_logprobs_dict` from + `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions 0.10.0 to 0.11.0). + + The original function does not apply temperature scaling to logits when + calculating prompt logprobs, which can lead to incorrect logprob values + when the temperature is not 1.0. This patch adds the missing + temperature scaling. + """ num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} @@ -89,7 +99,7 @@ def _get_prompt_logprobs_dict( # PATCH START temp = request.sampling_params.temperature - if temp is None or temp >= 1e-5: + if temp >= 1e-5: logits.div_(temp) # PATCH END From 8abbf0bb5ca3d1fea9feb177da208918eb60fe8e Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 13 Nov 2025 19:23:55 +0800 Subject: [PATCH 04/10] limit vllm version --- pyproject.toml | 2 +- trinity/common/models/vllm_patch/worker_patch.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfeedad03b..2134840c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13" dependencies = [ "verl==0.5.0", "ray[default]>=2.48.0", - "vllm>=0.10.0,<=0.11.0", + "vllm>=0.10.2,<=0.11.0", "tensordict", "wandb", "omegaconf", diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index 1a5a358593..429cf19086 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -12,10 +12,10 @@ def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): """Patch vLLM model runner to support prompt logprobs extraction.""" - if get_vllm_version() < parse_version("0.10.0"): + if get_vllm_version() < parse_version("0.10.2"): raise ValueError( f"Unsupported vllm version: {vllm.__version__}. " - "This patch requires vllm version >= 0.10.0, <= 0.11.0." + "This patch requires vllm version >= 0.10.2, <= 0.11.0." ) def _get_prompt_logprobs_dict( From cff84bfe987fa496df15bcc8d6db94374915d99b Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 13 Nov 2025 19:41:01 +0800 Subject: [PATCH 05/10] add temperature in logprobs/convert interface --- trinity/common/models/model.py | 43 ++++++++++++++++++++++------- trinity/common/models/vllm_model.py | 19 +++++++++++-- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index f09756511d..a2ecdd90a4 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -31,11 +31,16 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: """Generate experiences from a list of history chat messages in async.""" raise NotImplementedError - async def logprobs(self, tokens: List[int]) -> Tensor: + async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor: """Generate logprobs for a list of tokens in async.""" raise NotImplementedError - async def convert_messages_to_experience(self, messages: List[dict]) -> Experience: + async def convert_messages_to_experience( + self, + messages: List[dict], + tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, + ) -> Experience: """Convert a list of messages into an experience in async.""" raise NotImplementedError @@ -205,21 +210,39 @@ async def chat_mm_async( ) -> List[Experience]: return await self.model.chat_mm.remote(messages, images=images, videos=videos, **kwargs) - def logprobs(self, tokens: List[int]) -> Tensor: + def logprobs(self, tokens: List[int], temperature: Optional[float] = None) -> Tensor: """Calculate the logprobs of the given tokens.""" - return ray.get(self.model.logprobs.remote(tokens)) + return ray.get(self.model.logprobs.remote(tokens, temperature=temperature)) - async def logprobs_async(self, tokens: List[int]) -> Tensor: + async def logprobs_async( + self, tokens: List[int], temperature: Optional[float] = None + ) -> Tensor: """Calculate the logprobs of the given tokens in async.""" - return await self.model.logprobs.remote(tokens) + return await self.model.logprobs.remote(tokens, temperature=temperature) - def convert_messages_to_experience(self, messages: List[dict]) -> Experience: + def convert_messages_to_experience( + self, + messages: List[dict], + tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, + ) -> Experience: """Convert a list of messages into an experience.""" - return ray.get(self.model.convert_messages_to_experience.remote(messages)) + return ray.get( + self.model.convert_messages_to_experience.remote( + messages, tools=tools, temperature=temperature + ) + ) - async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience: + async def convert_messages_to_experience_async( + self, + messages: List[dict], + tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, + ) -> Experience: """Convert a list of messages into an experience in async.""" - return await self.model.convert_messages_to_experience.remote(messages) + return await self.model.convert_messages_to_experience.remote( + messages, tools=tools, temperature=temperature + ) @property def model_version(self) -> int: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 3bfe70bb93..2dbc2e3140 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -307,8 +307,11 @@ async def generate_mm( ] return experiences - async def logprobs( - self, token_ids: List[int], lora_request: LoRARequest = None + async def logprobs( # type: ignore [override] + self, + token_ids: List[int], + lora_request: LoRARequest = None, + temperature: Optional[float] = None, ) -> torch.Tensor: """Calculate the logprobs of the given tokens in async. Please slice the result carefully to align with the actual response length. @@ -316,16 +319,22 @@ async def logprobs( Args: token_ids (List[int]): The input token ids (seq_length). Please make sure the length of it does not exceed `max_model_len - 1`. + lora_request (LoRARequest, optional): The LoRA request. Defaults to None. + temperature (float): The temperature for scaling logits. Returns: A tensor of logprobs (seq_length - 1). """ + temperature = temperature if temperature is not None else self.config.temperature + if temperature is None: + temperature = 1.0 output = await self._generate_internal( prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, n=1, max_tokens=1, prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token + temperature=temperature, ) return torch.tensor( [list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]], @@ -357,6 +366,7 @@ async def convert_messages_to_experience( self, messages: List[dict], tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, ) -> Experience: """Convert a list of messages into an experience.""" if self.tokenizer is None: @@ -370,7 +380,10 @@ async def convert_messages_to_experience( chat_template=self.chat_template, enable_thinking=self.enable_thinking, ) # (seq_length, ), (seq_length, ) - logprobs = await self.logprobs(token_ids=token_ids.tolist()) # (seq_length - 1,) + temperature = temperature if temperature is not None else self.config.temperature + logprobs = await self.logprobs( + token_ids=token_ids.tolist(), temperature=temperature + ) # (seq_length - 1,) return Experience( tokens=token_ids, logprobs=logprobs[prompt_length - 1 :], From 97fce6620d9cf64d16bf8b3e80f29f247e2049e0 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 13 Nov 2025 21:50:11 +0800 Subject: [PATCH 06/10] add vllm tests --- tests/common/vllm_test.py | 83 +++++++++++++++++++++++++++++ trinity/common/models/vllm_model.py | 1 + 2 files changed, 84 insertions(+) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index c7fcd75a87..f893fb935e 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -1,5 +1,6 @@ import unittest +import ray import torch from openai import BadRequestError from parameterized import parameterized_class @@ -11,6 +12,7 @@ get_model_path, get_template_config, ) +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper from trinity.common.models.utils import ( @@ -361,6 +363,87 @@ async def test_api(self): self.assertEqual(len(self.model_wrapper_no_history.history), 0) +class DummySynchronizer: + def __init__(self): + pass + + def do_nothing(self): + pass + + +class TestLogprobs(RayUnittestBaseAysnc): + def setUp(self): + self.config = get_template_config() + self.config.mode = "explore" + self.config.model.model_path = get_model_path() + self.config.explorer.rollout_model.engine_type = "vllm" + self.config.explorer.rollout_model.engine_num = 1 + self.config.explorer.rollout_model.tensor_parallel_size = 1 + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE + self.config.explorer.rollout_model.enable_openai_api = True + + self.config.check_and_update() + self.engines, self.auxiliary_engines = create_inference_models(self.config) + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) + + async def test_logprobs(self): + # use init process group to apply patches + sync = ( + ray.remote(DummySynchronizer) + .options(name="synchronizer", namespace=self.config.ray_namespace) + .remote() + ) + await sync.__ray_ready__.remote() + await self.model_wrapper.prepare() + master_address, master_port = await self.engines[0].get_available_address.remote() + await self.engines[0].init_process_group.remote( + master_address, + master_port, + world_size=1, + rank_offset=0, + group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + explorer_name=self.config.explorer.name, + timeout=20, + ) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is your name?"}, + ] + response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0] + response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.5, logprobs=True)[0] + self.assertTrue(response_1.logprobs is not None) + self.assertTrue(len(response_1.logprobs) > 0) + self.assertTrue(response_2.logprobs is not None) + self.assertTrue(len(response_2.logprobs) > 0) + logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0) + logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.5) + logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0) + logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.5) + self.assertEqual(logprobs_1.shape, logprobs_2.shape) + self.assertFalse(torch.equal(logprobs_1, logprobs_2)) + logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] + logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] + logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] + logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] + self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) + self.assertFalse(torch.isclose(logprobs_1_prompt, logprobs_2_prompt, atol=1e-2).all()) + self.assertFalse(torch.isclose(logprobs_3_prompt, logprobs_4_prompt, atol=1e-2).all()) + self.assertTrue(torch.isclose(logprobs_1_prompt, logprobs_3_prompt, atol=1e-2).all()) + self.assertTrue(torch.isclose(logprobs_2_prompt, logprobs_4_prompt, atol=1e-2).all()) + logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] + logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] + logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] + logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :] + self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) + self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) + self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) + self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) + self.assertTrue(torch.isclose(response_1.logprobs, logprobs_1_response, atol=1e-2).all()) + self.assertFalse(torch.isclose(response_1.logprobs, logprobs_2_response, atol=1e-2).all()) + self.assertTrue(torch.isclose(response_2.logprobs, logprobs_4_response, atol=1e-2).all()) + self.assertFalse(torch.isclose(response_2.logprobs, logprobs_3_response, atol=1e-2).all()) + + class TestAsyncAPIServer(RayUnittestBaseAysnc): def setUp(self): self.config = get_template_config() diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 2dbc2e3140..d2d3b25c68 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -100,6 +100,7 @@ def __init__( }, disable_log_stats=True, enable_lora=config.enable_lora, + logprobs_mode="processed_logprobs", **config.lora_kwargs, ) if get_vllm_version() > parse_version("0.10.0"): From 1a64af89ef7082ded3cf35cbee011925b518d1e7 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 14 Nov 2025 12:07:40 +0800 Subject: [PATCH 07/10] update commits --- trinity/common/models/vllm_patch/worker_patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index 429cf19086..dda24cb1d6 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -26,7 +26,8 @@ def _get_prompt_logprobs_dict( """Patched version of _get_prompt_logprobs_dict. This is a monkey-patched version of `_get_prompt_logprobs_dict` from - `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions 0.10.0 to 0.11.0). + `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions + 0.10.2 to 0.11.0). The original function does not apply temperature scaling to logits when calculating prompt logprobs, which can lead to incorrect logprob values From e7b0345f6262be130f1b520f2a1801cb67b17b52 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 14 Nov 2025 12:26:20 +0800 Subject: [PATCH 08/10] fix logprobs tests --- tests/common/vllm_test.py | 30 +++++++++++-------- .../common/models/vllm_patch/worker_patch.py | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index f893fb935e..5a9956b18f 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -312,8 +312,9 @@ async def test_api(self): ) self.assertEqual(2, len(response.choices)) self.assertTrue(response.choices[0].logprobs is not None) - self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) - self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) + self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs)) + # here we check the 3rd token logprob, because the first two tokens (``,`\n` usually have zero logprob) + self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0) self.assertTrue(hasattr(response, "prompt_token_ids")) self.assertTrue(len(response.prompt_token_ids) > 0) self.assertTrue(hasattr(response.choices[0], "token_ids")) @@ -420,16 +421,18 @@ async def test_logprobs(self): logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0) logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.5) self.assertEqual(logprobs_1.shape, logprobs_2.shape) - self.assertFalse(torch.equal(logprobs_1, logprobs_2)) + self.assertEqual(logprobs_3.shape, logprobs_4.shape) + self.assertFalse(torch.isclose(logprobs_1, logprobs_2, atol=1e-1).all()) + self.assertFalse(torch.isclose(logprobs_3, logprobs_4, atol=1e-1).all()) logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) - self.assertFalse(torch.isclose(logprobs_1_prompt, logprobs_2_prompt, atol=1e-2).all()) - self.assertFalse(torch.isclose(logprobs_3_prompt, logprobs_4_prompt, atol=1e-2).all()) - self.assertTrue(torch.isclose(logprobs_1_prompt, logprobs_3_prompt, atol=1e-2).all()) - self.assertTrue(torch.isclose(logprobs_2_prompt, logprobs_4_prompt, atol=1e-2).all()) + self.assertFalse(torch.isclose(logprobs_1_prompt, logprobs_2_prompt, atol=1e-1).all()) + self.assertFalse(torch.isclose(logprobs_3_prompt, logprobs_4_prompt, atol=1e-1).all()) + self.assertTrue(torch.isclose(logprobs_1_prompt, logprobs_3_prompt, atol=1e-1).all()) + self.assertTrue(torch.isclose(logprobs_2_prompt, logprobs_4_prompt, atol=1e-1).all()) logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] @@ -438,10 +441,10 @@ async def test_logprobs(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.isclose(response_1.logprobs, logprobs_1_response, atol=1e-2).all()) - self.assertFalse(torch.isclose(response_1.logprobs, logprobs_2_response, atol=1e-2).all()) - self.assertTrue(torch.isclose(response_2.logprobs, logprobs_4_response, atol=1e-2).all()) - self.assertFalse(torch.isclose(response_2.logprobs, logprobs_3_response, atol=1e-2).all()) + self.assertTrue(torch.isclose(response_1.logprobs, logprobs_1_response, atol=1e-1).all()) + self.assertFalse(torch.isclose(response_1.logprobs, logprobs_2_response, atol=1e-1).all()) + self.assertTrue(torch.isclose(response_2.logprobs, logprobs_4_response, atol=1e-1).all()) + self.assertFalse(torch.isclose(response_2.logprobs, logprobs_3_response, atol=1e-1).all()) class TestAsyncAPIServer(RayUnittestBaseAysnc): @@ -486,8 +489,9 @@ async def test_api_async(self): ) self.assertEqual(2, len(response.choices)) self.assertTrue(response.choices[0].logprobs is not None) - self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) - self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) + self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs)) + # here we check the 3rd token logprob, because the first two tokens (``,`\n` usually have zero logprob) + self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0) self.assertTrue(hasattr(response, "prompt_token_ids")) self.assertTrue(len(response.prompt_token_ids) > 0) self.assertTrue(hasattr(response.choices[0], "token_ids")) diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index dda24cb1d6..ebe9d47ac3 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -26,7 +26,7 @@ def _get_prompt_logprobs_dict( """Patched version of _get_prompt_logprobs_dict. This is a monkey-patched version of `_get_prompt_logprobs_dict` from - `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions + `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions 0.10.2 to 0.11.0). The original function does not apply temperature scaling to logits when From 1ddb8365cc914088e71f9bbfbdb05b3adb179f64 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 14 Nov 2025 13:57:30 +0800 Subject: [PATCH 09/10] fix vllm logprobs --- tests/common/vllm_test.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 5a9956b18f..f60d39ccf9 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -411,28 +411,28 @@ async def test_logprobs(self): {"role": "user", "content": "What is your name?"}, ] response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0] - response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.5, logprobs=True)[0] + response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0] self.assertTrue(response_1.logprobs is not None) self.assertTrue(len(response_1.logprobs) > 0) self.assertTrue(response_2.logprobs is not None) self.assertTrue(len(response_2.logprobs) > 0) logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0) - logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.5) + logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8) logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0) - logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.5) + logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) self.assertEqual(logprobs_1.shape, logprobs_2.shape) self.assertEqual(logprobs_3.shape, logprobs_4.shape) - self.assertFalse(torch.isclose(logprobs_1, logprobs_2, atol=1e-1).all()) - self.assertFalse(torch.isclose(logprobs_3, logprobs_4, atol=1e-1).all()) + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4)) logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) - self.assertFalse(torch.isclose(logprobs_1_prompt, logprobs_2_prompt, atol=1e-1).all()) - self.assertFalse(torch.isclose(logprobs_3_prompt, logprobs_4_prompt, atol=1e-1).all()) - self.assertTrue(torch.isclose(logprobs_1_prompt, logprobs_3_prompt, atol=1e-1).all()) - self.assertTrue(torch.isclose(logprobs_2_prompt, logprobs_4_prompt, atol=1e-1).all()) + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4)) + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] @@ -441,10 +441,10 @@ async def test_logprobs(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.isclose(response_1.logprobs, logprobs_1_response, atol=1e-1).all()) - self.assertFalse(torch.isclose(response_1.logprobs, logprobs_2_response, atol=1e-1).all()) - self.assertTrue(torch.isclose(response_2.logprobs, logprobs_4_response, atol=1e-1).all()) - self.assertFalse(torch.isclose(response_2.logprobs, logprobs_3_response, atol=1e-1).all()) + self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) + self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) + self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) + self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) class TestAsyncAPIServer(RayUnittestBaseAysnc): From 9164c86c29085b0963de1dc2c90da7afd0d91212 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 14 Nov 2025 15:11:41 +0800 Subject: [PATCH 10/10] update sphinx timeout --- .github/workflows/sphinx-doc.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/sphinx-doc.yaml b/.github/workflows/sphinx-doc.yaml index c92cfb6d65..fc7963bbbf 100644 --- a/.github/workflows/sphinx-doc.yaml +++ b/.github/workflows/sphinx-doc.yaml @@ -11,7 +11,7 @@ on: jobs: pages: - timeout-minutes: 20 + timeout-minutes: 30 runs-on: ${{ matrix.os }} strategy: matrix: