From 24660de7c10a17abe534f72491e34d026b39290f Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 11:25:40 +0800 Subject: [PATCH 01/18] support vllm v0.12 --- pyproject.toml | 4 +- scripts/docker/Dockerfile.uv | 7 +- tests/common/vllm_test.py | 1 + trinity/common/models/model.py | 36 +++- trinity/common/models/vllm_model.py | 86 +++++---- trinity/common/models/vllm_patch/api_patch.py | 9 +- .../common/models/vllm_patch/api_patch_v12.py | 173 ++++++++++++++++++ .../common/models/vllm_patch/worker_patch.py | 136 +++++++++++++- 8 files changed, 399 insertions(+), 53 deletions(-) create mode 100644 trinity/common/models/vllm_patch/api_patch_v12.py diff --git a/pyproject.toml b/pyproject.toml index 225079c305..64dfabd3da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ classifiers = [ requires-python = ">=3.10,<3.13" dependencies = [ "verl==0.5.0", - "ray[default]>=2.48.0", - "vllm>=0.10.2,<=0.11.0", + "ray[default]>=2.50.0", + "vllm>=0.10.2,<=0.12.0", "tensordict", "wandb", "omegaconf", diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index 3d40a279c5..8718986a9c 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -22,15 +22,14 @@ RUN chmod 1777 /tmp && apt update && apt install -y \ && ln -sf /usr/bin/python3 /usr/bin/python \ && ln -sf /usr/bin/pip3 /usr/bin/pip -# For Aliyun users: update pip mirror to aliyun to speed up pip install -# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/ -# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com - ENV VIRTUAL_ENV=/opt/venv # copy the Trinity-RFT dir into the workspace COPY . . +# For Aliyun users: update pip mirror to aliyun to speed up pip install +# ENV UV_DEFAULT_INDEX=http://mirrors.cloud.aliyuncs.com/pypi/simple/ + # Install uv RUN pip install uv && uv venv /opt/venv --python=python3.12 diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index db7156cd4d..2f84f6e8da 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -496,6 +496,7 @@ def setUp(self): self.config.explorer.rollout_model.tensor_parallel_size = 1 self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE self.config.explorer.rollout_model.enable_openai_api = True + self.config.explorer.rollout_model.enable_log_requests = True self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index e08062f401..157e4416f2 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -3,7 +3,6 @@ import asyncio import socket from abc import ABC, abstractmethod -from functools import partial from typing import Dict, List, Optional, Sequence, Tuple, Union import httpx @@ -96,7 +95,17 @@ def __init__( engine_type: str = "vllm", enable_lora: bool = False, enable_history: bool = False, + enbale_thinking: Optional[bool] = None, ): + """Initialize the ModelWrapper. + + Args: + model (InferenceModel): The inference model Ray actor. + engine_type (str): The type of the model engine. Default to "vllm". + enable_lora (bool): Whether to enable LoRA. Default to False. + enable_history (bool): Whether to enable history recording. Default to False. + enbale_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models. + """ assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." 
self.model = model self.api_address: str = None @@ -105,6 +114,7 @@ def __init__( self.logger = get_logger(__name__) self.enable_lora = enable_lora self.enable_history = enable_history + self.enable_thinking = enbale_thinking self.history = [] self.status = RunningStatus.RUNNING self.workflow_state: Dict = {} @@ -303,10 +313,17 @@ def get_openai_client(self) -> openai.OpenAI: ) if self.enable_history: # add a decorator to the openai client to record history - ori_create = partial(self.openai_client.chat.completions.create, logprobs=True) + + ori_create = self.openai_client.chat.completions.create def record_chat_completions(*args, **kwargs): - response = ori_create(*args, **kwargs) + extra_body = kwargs.get("extra_body", {}) + if self.enable_thinking is not None: + if "chat_template_kwargs" not in extra_body: + extra_body["chat_template_kwargs"] = {} + extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking + extra_body["return_token_ids"] = True + response = ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs) self.history.extend(convert_api_output_to_experience(response)) return response @@ -333,10 +350,19 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: ) if self.enable_history: # add a decorator to the openai client to record history - ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True) + + ori_create = self.openai_async_client.chat.completions.create async def record_chat_completions(*args, **kwargs): - response = await ori_create(*args, **kwargs) + kwargs.pop("logprobs", True) + extra_body = kwargs.get("extra_body", {}) + if self.enable_thinking is not None: + if "chat_template_kwargs" not in extra_body: + extra_body["chat_template_kwargs"] = {} + extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking + extra_body["return_token_ids"] = True + # self.logger.info("args: %s, kwargs: %s, extra_body: %s", args, kwargs, extra_body) + response = await ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs) self.history.extend(convert_api_output_to_experience(response)) return response diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index dbddd53a5f..dc36bb1d00 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -7,12 +7,9 @@ import numpy as np import ray import torch -import vllm from packaging.version import parse as parse_version from PIL import Image from transformers import AutoProcessor -from vllm.lora.request import LoRARequest -from vllm.sampling_params import RequestOutputKind from trinity.common.config import InferenceModelConfig from trinity.common.experience import Experience @@ -22,7 +19,7 @@ ) from trinity.common.models.model import InferenceModel from trinity.common.models.utils import get_action_mask_method -from trinity.common.models.vllm_patch.api_patch import get_vllm_version +from trinity.common.models.vllm_patch import get_vllm_version from trinity.utils.log import get_logger @@ -38,20 +35,24 @@ def __init__( self, config: InferenceModelConfig, ) -> None: + import vllm + from vllm.sampling_params import RequestOutputKind + self.logger = get_logger(__name__) + self.vllm_version = get_vllm_version() self.config = config self.use_v1 = config.use_v1 if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices - if not vllm.envs.is_set("VLLM_USE_V1"): + if self.vllm_version <= 
parse_version("0.11.0") and not vllm.envs.is_set("VLLM_USE_V1"): self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) if config.use_v1: os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - if get_vllm_version() >= parse_version("0.11.0"): + if self.vllm_version >= parse_version("0.11.0"): os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" if not config.enforce_eager: # To avoid torch compile conflicts when multiple model are started simultaneously. @@ -99,7 +100,6 @@ def __init__( trust_remote_code=True, task="generate", gpu_memory_utilization=config.gpu_memory_utilization, - enable_chunked_prefill=config.enable_chunked_prefill, # max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage override_generation_config={ # TODO: find a way to unittest this "temperature": config.temperature, @@ -114,12 +114,13 @@ def __init__( **rope_kwargs, **config.lora_kwargs, ) - if get_vllm_version() > parse_version("0.10.0"): + if self.vllm_version > parse_version("0.10.0"): engine_args.enable_log_requests = config.enable_log_requests else: engine_args.disable_log_requests = not config.enable_log_requests - if get_vllm_version() >= parse_version("0.11.0"): + if self.vllm_version >= parse_version("0.11.0"): engine_args.reasoning_parser = config.reasoning_parser + self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.processor = None self.tokenizer = None @@ -157,9 +158,7 @@ async def prepare( await self.run_api_server() self._prepared = True - async def chat( - self, messages: List[Dict], lora_request: LoRARequest = None, **kwargs - ) -> Sequence[Experience]: + async def chat(self, messages: List[Dict], lora_request=None, **kwargs) -> Sequence[Experience]: """Chat with the model with a list of messages in async. Args: @@ -190,9 +189,7 @@ async def chat( ) return await self.generate(prompt=prompt, lora_request=lora_request, **kwargs) - async def generate( - self, prompt: str, lora_request: LoRARequest = None, **kwargs - ) -> Sequence[Experience]: + async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]: """Generate a response from the provided prompt in async. Args: @@ -361,7 +358,7 @@ async def generate_mm( async def logprobs( # type: ignore [override] self, token_ids: List[int], - lora_request: LoRARequest = None, + lora_request=None, temperature: Optional[float] = None, ) -> torch.Tensor: """Calculate the logprobs of the given tokens in async. Please slice the result carefully @@ -392,9 +389,7 @@ async def logprobs( # type: ignore [override] dtype=torch.float32, ) - async def _generate_internal( - self, prompt: Any, lora_request: LoRARequest = None, **kwargs - ) -> Any: + async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any: # Send the request to the LLM engine. self.request_id += 1 stream = self.async_llm.generate( @@ -561,23 +556,42 @@ async def run_api_server(self) -> bool: self.logger.info("OpenAI API server is already running. 
Skipping...") return True # already running - from trinity.common.models.vllm_patch.api_patch import ( - run_api_server_in_ray_actor, - ) - api_server_host, api_server_port = self.get_available_address() - self.api_server = asyncio.create_task( - run_api_server_in_ray_actor( - self.async_llm, - api_server_host, - api_server_port, - self.config.model_path, # type: ignore [arg-type] - self.config.enable_auto_tool_choice, - self.config.tool_call_parser, - self.config.reasoning_parser, - self.config.enable_log_requests, + if self.vllm_version <= parse_version("0.11.0"): + from trinity.common.models.vllm_patch.api_patch import ( + run_api_server_in_ray_actor, + ) + + self.api_server = asyncio.create_task( + run_api_server_in_ray_actor( + self.async_llm, + api_server_host, + api_server_port, + self.config.model_path, # type: ignore [arg-type] + self.config.enable_auto_tool_choice, + self.config.tool_call_parser, + self.config.reasoning_parser, + self.config.enable_log_requests, + ) + ) + else: + from trinity.common.models.vllm_patch.api_patch_v12 import ( + run_api_server_in_ray_actor_v12, + ) + + self.api_server = asyncio.create_task( + run_api_server_in_ray_actor_v12( + self.async_llm, + api_server_host, + api_server_port, + self.config.model_path, # type: ignore [arg-type] + logger=self.logger, + enable_auto_tool_choice=self.config.enable_auto_tool_choice, + tool_call_parser=self.config.tool_call_parser, + reasoning_parser=self.config.reasoning_parser, + enable_log_requests=self.config.enable_log_requests, + ) ) - ) self.api_server_host = api_server_host self.api_server_port = api_server_port return True @@ -604,7 +618,9 @@ def get_model_version(self) -> int: def get_model_path(self) -> str: return self.config.model_path # type: ignore [return-value] - def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest: + def get_lora_request(self, lora_path: Optional[str] = None) -> Any: + from vllm.lora.request import LoRARequest + assert self.config.lora_modules is not None lora_request = LoRARequest(**self.config.lora_modules[0]) if lora_path is not None: diff --git a/trinity/common/models/vllm_patch/api_patch.py b/trinity/common/models/vllm_patch/api_patch.py index 0036b0956c..623f9c04af 100644 --- a/trinity/common/models/vllm_patch/api_patch.py +++ b/trinity/common/models/vllm_patch/api_patch.py @@ -1,4 +1,4 @@ -"""Patch for vllm OpenAI API server. +"""Patch for vllm OpenAI API server. Only for vllm versions >=0.8.5, <=0.11.0. 1. Mocks the `add_signal_handler` method to do nothing. 2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. 
@@ -51,7 +51,6 @@ class PatchedChatCompletionResponse(ChatCompletionResponse): choices: list[PatchedChatCompletionResponseChoice] = list[ChatCompletionResponseChoice] -# TODO: add patch to stream generator async def chat_completion_full_generator( # noqa C901 self, request, @@ -304,7 +303,11 @@ async def patch_and_serve_http(app, sock, args): loop = asyncio.get_event_loop() original_add_signal_handler = loop.add_signal_handler loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop) - OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator + vllm_version = get_vllm_version() + + # from 0.10.2, vllm added token_ids to ChatCompletionResponseChoice, so no need to patch + if vllm_version < parse_version("0.10.2"): + OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator try: shutdown_task = await serve_http( diff --git a/trinity/common/models/vllm_patch/api_patch_v12.py b/trinity/common/models/vllm_patch/api_patch_v12.py new file mode 100644 index 0000000000..da43bfd9a7 --- /dev/null +++ b/trinity/common/models/vllm_patch/api_patch_v12.py @@ -0,0 +1,173 @@ +"""Patch for vllm OpenAI API server. Only for vllm versions >=0.8.5, <=0.11.0. + +1. Mocks the `add_signal_handler` method to do nothing. +2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. +""" +import logging +from typing import Optional + +import vllm +import vllm.envs as envs +from packaging.version import parse as parse_version +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + create_server_unix_socket, + init_app_state, + validate_api_server_args, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import log_non_default_args +from vllm.reasoning import ReasoningParserManager +from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.utils.system_utils import set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +from trinity.common.models.vllm_patch import get_vllm_version + + +def setup_server_in_ray(args, logger): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + if args.uds: + sock = create_server_unix_socket(args.uds) + else: + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + if args.uds: + listen_address = f"unix:{args.uds}" + else: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + return listen_address, sock + + +async def run_server_worker_in_ray( + listen_address, + sock, + args, + engine_client, + logger, +) -> None: + # Modified from vllm.entrypoints.openai.api_server.run_server_worker + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: + ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) + + app = build_app(args) + + await init_app_state(engine_client, app.state, args) + + logger.info( + "Starting vLLM API server %d on %s", + engine_client.vllm_config.parallel_config._api_process_rank, + listen_address, + ) + + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + h11_max_incomplete_event_size=args.h11_max_incomplete_event_size, + h11_max_header_count=args.h11_max_header_count, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +async def run_server_in_ray(args, engine_client, logger): + # Modified from vllm.entrypoints.openai.api_server.run_server + listen_address, sock = setup_server_in_ray(args, logger) + logger.info("vLLM API server listening on %s", listen_address) + await run_server_worker_in_ray(listen_address, sock, args, engine_client, logger) + + +def dummy_add_signal_handler(self, *args, **kwargs): + # DO NOTHING HERE + pass + + +async def run_api_server_in_ray_actor_v12( + async_llm, + host: str, + port: int, + model_path: str, + logger: logging.Logger, + enable_auto_tool_choice: bool = False, + tool_call_parser: Optional[str] = None, + reasoning_parser: Optional[str] = None, + enable_log_requests: bool = False, +): + vllm_version = get_vllm_version() + if vllm_version <= parse_version("0.11.0"): + raise ValueError( + f"Unsupported vllm version: {vllm.__version__}. 
" + "This patch requires vllm version > 0.11.0" + ) + + parser = FlexibleArgumentParser(description="Run the OpenAI API server.") + args = make_arg_parser(parser) + cli_args = [ + "--host", + str(host), + "--port", + str(port), + "--model", + model_path, + "--enable-server-load-tracking", # enable tracking for load balancing + ] + if enable_log_requests: + cli_args.append("--enable-log-requests") + if enable_auto_tool_choice: + cli_args.append("--enable-auto-tool-choice") + if tool_call_parser: + cli_args.extend(["--tool-call-parser", tool_call_parser]) + if reasoning_parser: + cli_args.extend(["--reasoning-parser", reasoning_parser]) + args = parser.parse_args(cli_args) + if vllm_version >= parse_version("0.11.0"): + args.structured_outputs_config.reasoning_parser = reasoning_parser + logger.info(f"Starting vLLM OpenAI API server with args: {args}") + await run_server_in_ray(args, async_llm, logger) diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index ebe9d47ac3..d74a6057c8 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -10,7 +10,7 @@ from trinity.common.models.vllm_patch import get_vllm_version -def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): +def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901 """Patch vLLM model runner to support prompt logprobs extraction.""" if get_vllm_version() < parse_version("0.10.2"): raise ValueError( @@ -18,7 +18,7 @@ def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): "This patch requires vllm version >= 0.10.2, <= 0.11.0." ) - def _get_prompt_logprobs_dict( + def _get_prompt_logprobs_dict_v11( self, hidden_states: torch.Tensor, num_scheduled_tokens: dict[str, int], @@ -45,7 +45,128 @@ def _get_prompt_logprobs_dict( # maintainable loop over optimal performance. completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): - num_tokens = num_scheduled_tokens[req_id] + num_tokens = num_scheduled_tokens.get(req_id) + if num_tokens is None: + # This can happen if the request was preempted in prefill stage. + continue + + # Get metadata for this request. + request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True + ) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. 
+ num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc.np[req_idx].item() + prompt_hidden_states = hidden_states[offset : offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) + + # PATCH START + temp = request.sampling_params.temperature + if temp >= 1e-5: + logits.div_(temp) + # PATCH END + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids + ) + + # Transfer GPU->CPU async. + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + if prompt_logprobs_dict: + self._sync_device() + + return prompt_logprobs_dict + + def _get_prompt_logprobs_dict_v12( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, Optional[LogprobsTensors]]: + """Patched version of _get_prompt_logprobs_dict. + + This is a monkey-patched version of `_get_prompt_logprobs_dict` from + `vllm.v1.worker.gpu_model_runner.GPUModelRunner` (vLLM versions + 0.10.2 to 0.11.0). + + The original function does not apply temperature scaling to logits when + calculating prompt logprobs, which can lead to incorrect logprob values + when the temperature is not 1.0. This patch adds the missing + temperature scaling. + """ + num_prompt_logprobs_dict = self.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + num_tokens = num_scheduled_tokens.get(req_id) + if num_tokens is None: + # This can happen if the request was preempted in prefill stage. + continue # Get metadata for this request. 
request = self.requests[req_id] @@ -133,4 +254,11 @@ def _get_prompt_logprobs_dict( return prompt_logprobs_dict - model_runner._get_prompt_logprobs_dict = MethodType(_get_prompt_logprobs_dict, model_runner) + if get_vllm_version() < parse_version("0.12.0"): + model_runner._get_prompt_logprobs_dict = MethodType( + _get_prompt_logprobs_dict_v11, model_runner + ) + else: + model_runner._get_prompt_logprobs_dict = MethodType( + _get_prompt_logprobs_dict_v12, model_runner + ) From f67a5377016b3e804ed3789f65c57d5e332559cd Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 14:22:08 +0800 Subject: [PATCH 02/18] fix dockerfile --- scripts/docker/Dockerfile | 2 +- scripts/docker/Dockerfile.megatron | 2 +- scripts/docker/Dockerfile.uv | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile index b87558f123..103ee2c1a6 100644 --- a/scripts/docker/Dockerfile +++ b/scripts/docker/Dockerfile @@ -6,7 +6,7 @@ # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest -FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/scripts/docker/Dockerfile.megatron b/scripts/docker/Dockerfile.megatron index 7659452311..c09dbb269c 100644 --- a/scripts/docker/Dockerfile.megatron +++ b/scripts/docker/Dockerfile.megatron @@ -6,7 +6,7 @@ # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft-megatron:latest -FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index 8718986a9c..b3bc0be759 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -9,7 +9,7 @@ # 1. This Dockerfile uses 'uv' to create a virtual environment for better package management. If you want a simpler setup without 'uv', please refer to `scripts/docker/Dockerfile`. # 2. Make sure to use `uv pip` to install packages within the virtual environment. 
-FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04 WORKDIR /workspace From b6b15123f50ebadd4b72a73d26e269f32e608cfb Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 14:48:58 +0800 Subject: [PATCH 03/18] fix logprobs test --- tests/common/vllm_test.py | 40 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 2f84f6e8da..69f9b50478 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -541,17 +541,17 @@ async def test_logprobs_api(self): logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) self.assertEqual(logprobs_1.shape, logprobs_2.shape) self.assertEqual(logprobs_3.shape, logprobs_4.shape) - self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4)) + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3)) logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) - self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] @@ -560,10 +560,10 @@ async def test_logprobs_api(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) - self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) - self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) - self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) + self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)) # test vllm engine logprobs with different temperature response_1 = self.model_wrapper.chat( @@ -582,17 +582,17 @@ async def test_logprobs_api(self): logprobs_4 = 
self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) self.assertEqual(logprobs_1.shape, logprobs_2.shape) self.assertEqual(logprobs_3.shape, logprobs_4.shape) - self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4)) + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3)) logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) - self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) - self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) - self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3)) logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] @@ -601,10 +601,10 @@ async def test_logprobs_api(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) - self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) - self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) - self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) + self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)) + self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)) + self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)) # test openai api and vllm engine logprobs consistency await self.model_wrapper.clean_workflow_state() From 979468a513df8c8559139375a5002d6063862bea Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 14:53:06 +0800 Subject: [PATCH 04/18] fix cuda version --- scripts/docker/Dockerfile | 2 +- scripts/docker/Dockerfile.megatron | 2 +- scripts/docker/Dockerfile.uv | 2 +- tests/common/vllm_test.py | 32 ++++++++++++++++++++++-------- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile index 103ee2c1a6..11c1af0bee 100644 --- a/scripts/docker/Dockerfile +++ b/scripts/docker/Dockerfile @@ -6,7 +6,7 @@ # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest -FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04 +FROM 
nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/scripts/docker/Dockerfile.megatron b/scripts/docker/Dockerfile.megatron index c09dbb269c..0f6942f275 100644 --- a/scripts/docker/Dockerfile.megatron +++ b/scripts/docker/Dockerfile.megatron @@ -6,7 +6,7 @@ # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft-megatron:latest -FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index b3bc0be759..f233a5d8a5 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -9,7 +9,7 @@ # 1. This Dockerfile uses 'uv' to create a virtual environment for better package management. If you want a simpler setup without 'uv', please refer to `scripts/docker/Dockerfile`. # 2. Make sure to use `uv pip` to install packages within the virtual environment. -FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 69f9b50478..9572257e2e 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -560,10 +560,18 @@ async def test_logprobs_api(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)) - self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)) - self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)) - self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)) + self.assertTrue( + torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3) + ) + self.assertTrue( + torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3) + ) # test vllm engine logprobs with different temperature response_1 = self.model_wrapper.chat( @@ -601,10 +609,18 @@ async def test_logprobs_api(self): self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) - self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)) - self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)) - self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)) - self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)) + self.assertTrue( + torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3) + ) + self.assertTrue( + torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3) + ) + self.assertFalse( + torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3) + ) # test openai api and 
vllm engine logprobs consistency await self.model_wrapper.clean_workflow_state() From a391663370cb38cefb510fb385c809cff8ce71fb Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 15:04:18 +0800 Subject: [PATCH 05/18] fix comments --- trinity/common/models/model.py | 6 +++--- trinity/common/models/vllm_patch/api_patch_v12.py | 7 +------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 157e4416f2..6331c9d353 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -95,7 +95,7 @@ def __init__( engine_type: str = "vllm", enable_lora: bool = False, enable_history: bool = False, - enbale_thinking: Optional[bool] = None, + enable_thinking: Optional[bool] = None, ): """Initialize the ModelWrapper. @@ -104,7 +104,7 @@ def __init__( engine_type (str): The type of the model engine. Default to "vllm". enable_lora (bool): Whether to enable LoRA. Default to False. enable_history (bool): Whether to enable history recording. Default to False. - enbale_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models. + enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models. """ assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model @@ -114,7 +114,7 @@ def __init__( self.logger = get_logger(__name__) self.enable_lora = enable_lora self.enable_history = enable_history - self.enable_thinking = enbale_thinking + self.enable_thinking = enable_thinking self.history = [] self.status = RunningStatus.RUNNING self.workflow_state: Dict = {} diff --git a/trinity/common/models/vllm_patch/api_patch_v12.py b/trinity/common/models/vllm_patch/api_patch_v12.py index da43bfd9a7..b926b158a1 100644 --- a/trinity/common/models/vllm_patch/api_patch_v12.py +++ b/trinity/common/models/vllm_patch/api_patch_v12.py @@ -1,4 +1,4 @@ -"""Patch for vllm OpenAI API server. Only for vllm versions >=0.8.5, <=0.11.0. +"""Patch for vllm OpenAI API server. Only for vllm versions >0.11.0. 1. Mocks the `add_signal_handler` method to do nothing. 2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`. 
@@ -124,11 +124,6 @@ async def run_server_in_ray(args, engine_client, logger): await run_server_worker_in_ray(listen_address, sock, args, engine_client, logger) -def dummy_add_signal_handler(self, *args, **kwargs): - # DO NOTHING HERE - pass - - async def run_api_server_in_ray_actor_v12( async_llm, host: str, From 8afaa1f79bfbe90201a74b73cd96c7e985e33449 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 15:52:33 +0800 Subject: [PATCH 06/18] fix openai client --- trinity/common/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 6331c9d353..b485f6ee4c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -317,6 +317,7 @@ def get_openai_client(self) -> openai.OpenAI: ori_create = self.openai_client.chat.completions.create def record_chat_completions(*args, **kwargs): + kwargs.pop("logprobs", True) extra_body = kwargs.get("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: @@ -361,7 +362,6 @@ async def record_chat_completions(*args, **kwargs): extra_body["chat_template_kwargs"] = {} extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking extra_body["return_token_ids"] = True - # self.logger.info("args: %s, kwargs: %s, extra_body: %s", args, kwargs, extra_body) response = await ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs) self.history.extend(convert_api_output_to_experience(response)) return response From ec329f78af4e6c05134062cfd1bba442749d36f4 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 15:53:17 +0800 Subject: [PATCH 07/18] fix comments --- trinity/common/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index b485f6ee4c..ae950a38c4 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -318,7 +318,7 @@ def get_openai_client(self) -> openai.OpenAI: def record_chat_completions(*args, **kwargs): kwargs.pop("logprobs", True) - extra_body = kwargs.get("extra_body", {}) + extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: extra_body["chat_template_kwargs"] = {} @@ -356,7 +356,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: async def record_chat_completions(*args, **kwargs): kwargs.pop("logprobs", True) - extra_body = kwargs.get("extra_body", {}) + extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: extra_body["chat_template_kwargs"] = {} From 62cca82c8c50cde568e720f667fcf52619248788 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 16:03:35 +0800 Subject: [PATCH 08/18] fix tests --- tests/common/vllm_test.py | 4 ++-- trinity/common/models/model.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 9572257e2e..18731f361e 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -442,7 +442,7 @@ async def test_api(self): ) self.assertEqual(2, len(response.choices)) self.assertTrue(hasattr(response.choices[0], "token_ids")) - self.assertTrue(len(response.choices[0].token_ids) > 0) + self.assertTrue(response.choices[0].token_ids is None) with self.assertRaises(ValueError): self.model_wrapper_no_history.extract_experience_from_history() 
self.assertEqual(len(self.model_wrapper_no_history.history), 0) @@ -764,7 +764,7 @@ async def test_api_async(self): ) self.assertEqual(2, len(response.choices)) self.assertTrue(hasattr(response.choices[0], "token_ids")) - self.assertTrue(len(response.choices[0].token_ids) > 0) + self.assertTrue(response.choices[0].token_ids is None) with self.assertRaises(ValueError): self.model_wrapper_no_history.extract_experience_from_history() self.assertEqual(len(self.model_wrapper_no_history.history), 0) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index ae950a38c4..800f698e10 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -317,14 +317,14 @@ def get_openai_client(self) -> openai.OpenAI: ori_create = self.openai_client.chat.completions.create def record_chat_completions(*args, **kwargs): - kwargs.pop("logprobs", True) + logprobs = kwargs.pop("logprobs", True) extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: extra_body["chat_template_kwargs"] = {} extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking extra_body["return_token_ids"] = True - response = ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs) + response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs) self.history.extend(convert_api_output_to_experience(response)) return response @@ -355,14 +355,16 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: ori_create = self.openai_async_client.chat.completions.create async def record_chat_completions(*args, **kwargs): - kwargs.pop("logprobs", True) + logprobs = kwargs.pop("logprobs", True) extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: extra_body["chat_template_kwargs"] = {} extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking extra_body["return_token_ids"] = True - response = await ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs) + response = await ori_create( + *args, extra_body=extra_body, logprobs=logprobs, **kwargs + ) self.history.extend(convert_api_output_to_experience(response)) return response From 5df7567672e853cb79277d365e89cbaace7c2d81 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 16:16:37 +0800 Subject: [PATCH 09/18] update docker image --- .github/workflows/docker/docker-compose.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index b6756f659a..e7724b3771 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -1,6 +1,6 @@ services: trinity-node-1: - image: trinity-rft-unittest:20251030 + image: trinity-rft-unittest:20251212 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block" environment: @@ -29,7 +29,7 @@ services: capabilities: [gpu] trinity-node-2: - image: trinity-rft-unittest:20251030 + image: trinity-rft-unittest:20251212 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block" environment: From 46a5f83a5d70b64d5a43bc5cd6047cc752d7098b Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 16:41:32 +0800 Subject: [PATCH 10/18] fix rope paramters --- trinity/common/models/vllm_model.py | 18 +++++++++++++----- 1 file changed, 13 
insertions(+), 5 deletions(-) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index dc36bb1d00..67ff0dc2ff 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -2,6 +2,7 @@ import asyncio import os +from collections import defaultdict from typing import Any, Dict, List, Optional, Sequence import numpy as np @@ -82,11 +83,18 @@ def __init__( max_model_len = config.max_model_len self.enable_lora = config.enable_lora self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None) - rope_kwargs = { - key: getattr(config, key) - for key in ["rope_scaling", "rope_theta"] - if getattr(config, key) is not None - } + if self.vllm_version >= parse_version("0.12.0"): + rope_kwargs = defaultdict() + if config.rope_scaling is not None: + rope_kwargs["rope_parameters"] = config.rope_scaling + if config.rope_theta is not None: + rope_kwargs["rope_parameters"]["rope_theta"] = config.rope_theta + else: + rope_kwargs = { + key: getattr(config, key) + for key in ["rope_scaling", "rope_theta"] + if getattr(config, key) is not None + } engine_args = vllm.AsyncEngineArgs( model=config.model_path, enforce_eager=config.enforce_eager, From ff843be6cee46fca86d5837e3130b932d1065899 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 17:05:32 +0800 Subject: [PATCH 11/18] update doc --- .../sphinx_doc/source/tutorial/trinity_installation.md | 2 +- .../source_zh/tutorial/trinity_installation.md | 2 +- trinity/common/models/vllm_model.py | 10 +++++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_installation.md b/docs/sphinx_doc/source/tutorial/trinity_installation.md index bd72967556..1ad1b030fe 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_installation.md +++ b/docs/sphinx_doc/source/tutorial/trinity_installation.md @@ -6,7 +6,7 @@ For installing Trinity-RFT, you have three options: from source (recommended), v Before installing, ensure your system meets the following requirements: - **Python**: Version 3.10 to 3.12 (inclusive) -- **CUDA**: Version >= 12.8 +- **CUDA**: Version >= 12.8 (If using vLLM v0.11.1 or higher, CUDA >= 12.9 is required) - **GPUs**: At least 2 GPUs --- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md index d337e84960..07ae9d659a 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md @@ -6,7 +6,7 @@ 在安装前,请确保您的系统满足以下要求: - **Python**:3.10 至 3.12(包含) -- **CUDA**:大于等于 12.8 +- **CUDA**:大于等于 12.8 (如果使用 vLLM v0.11.1 或更高版本,则需要 CUDA >= 12.9) - **GPU**:至少 2 块 GPU --- diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 67ff0dc2ff..27feabf35d 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -84,11 +84,15 @@ def __init__( self.enable_lora = config.enable_lora self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None) if self.vllm_version >= parse_version("0.12.0"): - rope_kwargs = defaultdict() + rope_params = defaultdict(dict) if config.rope_scaling is not None: - rope_kwargs["rope_parameters"] = config.rope_scaling + rope_params["rope_parameters"] = config.rope_scaling if config.rope_theta is not None: - rope_kwargs["rope_parameters"]["rope_theta"] = config.rope_theta + rope_params["rope_parameters"]["rope_theta"] = config.rope_theta + if len(rope_params) > 0: + rope_kwargs = 
{"hf_overrides": rope_params} + else: + rope_kwargs = {} else: rope_kwargs = { key: getattr(config, key) From c3ff69b1902792a21a3635a368d7dcd1ed9bd75d Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 19:45:16 +0800 Subject: [PATCH 12/18] fix docker file --- docs/sphinx_doc/source/tutorial/trinity_installation.md | 2 +- docs/sphinx_doc/source_zh/tutorial/trinity_installation.md | 2 +- pyproject.toml | 2 +- scripts/docker/Dockerfile | 2 +- scripts/docker/Dockerfile.megatron | 2 +- scripts/docker/Dockerfile.uv | 4 ++-- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_installation.md b/docs/sphinx_doc/source/tutorial/trinity_installation.md index 1ad1b030fe..bd72967556 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_installation.md +++ b/docs/sphinx_doc/source/tutorial/trinity_installation.md @@ -6,7 +6,7 @@ For installing Trinity-RFT, you have three options: from source (recommended), v Before installing, ensure your system meets the following requirements: - **Python**: Version 3.10 to 3.12 (inclusive) -- **CUDA**: Version >= 12.8 (If using vLLM v0.11.1 or higher, CUDA >= 12.9 is required) +- **CUDA**: Version >= 12.8 - **GPUs**: At least 2 GPUs --- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md index 07ae9d659a..d337e84960 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md @@ -6,7 +6,7 @@ 在安装前,请确保您的系统满足以下要求: - **Python**:3.10 至 3.12(包含) -- **CUDA**:大于等于 12.8 (如果使用 vLLM v0.11.1 或更高版本,则需要 CUDA >= 12.9) +- **CUDA**:大于等于 12.8 - **GPU**:至少 2 块 GPU --- diff --git a/pyproject.toml b/pyproject.toml index 64dfabd3da..f7365cfde4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ dev = [ ] megatron = [ "megatron-core[mlm]==0.13.1", - "transformer_engine[pytorch]==2.8.0", + "transformer_engine[pytorch]==2.9.0", "mbridge>=0.13.0", ] diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile index 11c1af0bee..b87558f123 100644 --- a/scripts/docker/Dockerfile +++ b/scripts/docker/Dockerfile @@ -6,7 +6,7 @@ # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft:latest -FROM nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/scripts/docker/Dockerfile.megatron b/scripts/docker/Dockerfile.megatron index 0f6942f275..7659452311 100644 --- a/scripts/docker/Dockerfile.megatron +++ b/scripts/docker/Dockerfile.megatron @@ -6,7 +6,7 @@ # docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data trinity-rft-megatron:latest -FROM nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index f233a5d8a5..82d5389ada 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -9,7 +9,7 @@ # 1. This Dockerfile uses 'uv' to create a virtual environment for better package management. If you want a simpler setup without 'uv', please refer to `scripts/docker/Dockerfile`. # 2. Make sure to use `uv pip` to install packages within the virtual environment. -FROM nvcr.io/nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04 +FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 WORKDIR /workspace @@ -39,7 +39,7 @@ RUN . 
/opt/venv/bin/activate && \ # Install flash_attn and Megatron RUN . /opt/venv/bin/activate && \ - uv pip install flash_attn==2.8.1 --no-deps --no-cache-dir && \ + uv pip install flash_attn==2.8.1 --no-cache-dir && \ uv pip install -e .[megatron] && \ NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \ uv pip install -v --no-build-isolation \ From a8847ddf3e78d2ea8570321beed3449484c6597b Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 12 Dec 2025 19:47:45 +0800 Subject: [PATCH 13/18] update unittest test docker image --- .github/workflows/docker/docker-compose.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index e7724b3771..c2ea17e526 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -1,6 +1,6 @@ services: trinity-node-1: - image: trinity-rft-unittest:20251212 + image: trinity-rft-unittest:20251213 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block" environment: @@ -29,7 +29,7 @@ services: capabilities: [gpu] trinity-node-2: - image: trinity-rft-unittest:20251212 + image: trinity-rft-unittest:20251213 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block" environment: From ee9759054181e78383c56da77d76ffdf80def58f Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 15 Dec 2025 10:22:26 +0800 Subject: [PATCH 14/18] fix api --- trinity/explorer/api/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trinity/explorer/api/api.py b/trinity/explorer/api/api.py index 66b5e2e97a..b702c39f5a 100644 --- a/trinity/explorer/api/api.py +++ b/trinity/explorer/api/api.py @@ -15,6 +15,8 @@ async def chat_completions(request: Request): # Currently, we do not support streaming chat completions body = await request.json() + if "return_token_ids" not in body: + body["return_token_ids"] = True url = await request.app.state.service.allocate_model() try: async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client: From 832d5aaa55c46c9e8f4553ba84d6866867f79b47 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 16 Dec 2025 12:30:20 +0800 Subject: [PATCH 15/18] fix logprobs patch for vllm 0.10.2 --- trinity/common/config.py | 2 +- trinity/common/models/vllm_patch/worker_patch.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index 9738ec3b8e..b2a824a2db 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -484,7 +484,7 @@ class InferenceModelConfig: use_v1: bool = True enforce_eager: bool = False enable_prefix_caching: bool = False - enable_chunked_prefill: bool = False + enable_chunked_prefill: bool = True gpu_memory_utilization: float = 0.9 dtype: str = "bfloat16" seed: int = 42 diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index d74a6057c8..c58decbf7c 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -12,11 +12,13 @@ def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901 """Patch vLLM model runner to support prompt logprobs extraction.""" - if get_vllm_version() < parse_version("0.10.2"): + version = get_vllm_version() + if version < parse_version("0.10.2") or version > parse_version("0.12.0"): raise ValueError( 
f"Unsupported vllm version: {vllm.__version__}. " - "This patch requires vllm version >= 0.10.2, <= 0.11.0." + "This patch requires vllm version >= 0.10.2, <= 0.12.0." ) + is_v0102 = version == parse_version("0.10.2") def _get_prompt_logprobs_dict_v11( self, @@ -99,9 +101,12 @@ def _get_prompt_logprobs_dict_v11( req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() prompt_hidden_states = hidden_states[offset : offset + num_logits] - logits = self.model.compute_logits(prompt_hidden_states) - # PATCH START + if is_v0102: + logits = self.model.compute_logits(prompt_hidden_states, None) + else: + logits = self.model.compute_logits(prompt_hidden_states) + temp = request.sampling_params.temperature if temp >= 1e-5: logits.div_(temp) From 1aff31a2f76f18265815d24862e684cc9103993b Mon Sep 17 00:00:00 2001 From: pan-x-c Date: Tue, 16 Dec 2025 07:19:16 +0000 Subject: [PATCH 16/18] fix debug mode --- trinity/cli/launcher.py | 1 + trinity/common/models/model.py | 7 +++---- trinity/common/models/vllm_model.py | 1 + trinity/explorer/workflow_runner.py | 3 +++ 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 6e6c462c56..d9c0d95771 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -249,6 +249,7 @@ def debug( os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir load_plugins() config = load_config(config_path) + config.mode = "explore" config.check_and_update() sys.path.insert(0, os.getcwd()) config.ray_namespace = DEBUG_NAMESPACE diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 800f698e10..110db0a3b7 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -3,7 +3,7 @@ import asyncio import socket from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import httpx import numpy as np @@ -12,7 +12,6 @@ import torch from PIL import Image from torch import Tensor -from vllm.lora.request import LoRARequest from trinity.common.constants import RunningStatus from trinity.common.experience import Experience @@ -280,13 +279,13 @@ async def model_path_async(self) -> str: """Get the model path.""" return await self.model.get_model_path.remote() - def get_lora_request(self) -> Optional[LoRARequest]: + def get_lora_request(self) -> Any: if self.enable_lora: return ray.get(self.model.get_lora_request.remote()) else: return None - async def get_lora_request_async(self) -> Optional[LoRARequest]: + async def get_lora_request_async(self) -> Any: if self.enable_lora: return await self.model.get_lora_request.remote() else: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 27feabf35d..d1f1d3f61f 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -50,6 +50,7 @@ def __init__( self.logger.info(f"Using vLLM v{int(config.use_v1)} engine") os.environ["VLLM_USE_V1"] = str(int(config.use_v1)) if config.use_v1: + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1)) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 2590dae064..57c5c25bf0 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -278,6 
+278,9 @@ async def debug(self) -> None: with VizTracer(output_file=self.output_profiling_file): status, exps = await self.run_task(task, 1, 0) + if not status.ok and len(exps) == 0: + exps = self.model_wrapper.extract_experience_from_history() + self.logger.info(f"Debugging failed, extracting {len(exps)} experiences from history.") await self.sqlite_writer.write_async(exps) if status.ok: print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}") From 1392645bd498b002da5bf279f94b7eb2177a8157 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 16 Dec 2025 15:42:49 +0800 Subject: [PATCH 17/18] use old docker --- .github/workflows/docker/docker-compose.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index c2ea17e526..b6756f659a 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -1,6 +1,6 @@ services: trinity-node-1: - image: trinity-rft-unittest:20251213 + image: trinity-rft-unittest:20251030 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block" environment: @@ -29,7 +29,7 @@ services: capabilities: [gpu] trinity-node-2: - image: trinity-rft-unittest:20251213 + image: trinity-rft-unittest:20251030 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block" environment: From 2fe15a0262ab69326732fc3097e6616ca60c0f67 Mon Sep 17 00:00:00 2001 From: pxc Date: Tue, 16 Dec 2025 16:15:58 +0800 Subject: [PATCH 18/18] fix pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7365cfde4..f7a8162bfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.10,<3.13" dependencies = [ "verl==0.5.0", "ray[default]>=2.50.0", - "vllm>=0.10.2,<=0.12.0", + "vllm>=0.10.2,<=0.11.0", "tensordict", "wandb", "omegaconf", @@ -73,7 +73,7 @@ dev = [ ] megatron = [ "megatron-core[mlm]==0.13.1", - "transformer_engine[pytorch]==2.9.0", + "transformer_engine[pytorch]==2.8.0", "mbridge>=0.13.0", ]
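
The client-side change running through PATCH 01 and PATCH 06-08 is that the wrapped `chat.completions.create` forwards `enable_thinking` via `extra_body["chat_template_kwargs"]` and always sets `return_token_ids`, so the recorded history carries token IDs alongside logprobs. Below is a minimal sketch of the equivalent raw request against the patched API server; the base URL, model name, and prompt are illustrative placeholders, not values taken from this series.

    import openai

    # In practice the address comes from ModelWrapper.api_address.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    response = client.chat.completions.create(
        model="Qwen/Qwen3-8B",  # hypothetical model name
        messages=[{"role": "user", "content": "What is 1 + 1?"}],
        logprobs=True,
        extra_body={
            # what ModelWrapper injects when enable_thinking is not None
            "chat_template_kwargs": {"enable_thinking": False},
            # needed for choices[i].token_ids / prompt_token_ids to be populated
            "return_token_ids": True,
        },
    )
    print(response.choices[0].message.content)
    print(response.choices[0].token_ids)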
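
PATCH 10 and PATCH 11 reshape the RoPE overrides depending on the vLLM version: below 0.12 the config's `rope_scaling` / `rope_theta` are passed straight to `AsyncEngineArgs`, while from 0.12 onward they are folded into `hf_overrides` under a single `rope_parameters` dict. A small sketch of that mapping, assuming a config object exposing the same two optional fields:

    from packaging.version import parse as parse_version

    def build_rope_kwargs(config, vllm_version: str) -> dict:
        # From vLLM 0.12 the override lives under hf_overrides["rope_parameters"].
        if parse_version(vllm_version) >= parse_version("0.12.0"):
            rope_parameters = dict(config.rope_scaling or {})
            if config.rope_theta is not None:
                rope_parameters["rope_theta"] = config.rope_theta
            if rope_parameters:
                return {"hf_overrides": {"rope_parameters": rope_parameters}}
            return {}
        # Older vLLM versions accept the fields as top-level engine args.
        return {
            key: getattr(config, key)
            for key in ("rope_scaling", "rope_theta")
            if getattr(config, key) is not None
        }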
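
The worker patch carried through PATCH 01 and PATCH 15 exists because the stock `_get_prompt_logprobs_dict` computes prompt logprobs from raw logits, so they drift from sampling-time logprobs whenever the request temperature is not 1.0; the PATCH START/END block divides the logits by the temperature first, skipping near-zero (greedy) temperatures. The idea in isolation, with shapes assumed to be (num_tokens, vocab_size) for the logits and (num_tokens,) for the target token IDs:

    import torch

    def prompt_logprobs(logits: torch.Tensor, target_ids: torch.Tensor, temperature: float) -> torch.Tensor:
        # Log-probability of each target token under temperature-scaled logits.
        if temperature >= 1e-5:  # greedy requests keep the raw logits, as in the patch
            logits = logits / temperature
        logprobs = torch.log_softmax(logits, dim=-1)
        return logprobs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)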