4 changes: 2 additions & 2 deletions pyproject.toml
@@ -22,8 +22,8 @@ classifiers = [
requires-python = ">=3.10,<3.13"
dependencies = [
"verl==0.5.0",
"ray[default]>=2.48.0",
"vllm>=0.10.2,<=0.11.0",
"ray[default]>=2.50.0",
"vllm>=0.10.2,<=0.12.0",
"tensordict",
"wandb",
"omegaconf",
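
The bump above raises the ray floor to 2.50.0 and widens the allowed vllm range to <=0.12.0. A minimal sketch for checking an already-installed environment against the new constraints (uses only importlib.metadata and packaging; not part of this PR):

```python
from importlib.metadata import version

from packaging.version import parse

# Hypothetical sanity check mirroring the updated pyproject.toml ranges.
assert parse(version("ray")) >= parse("2.50.0"), "ray too old for this branch"
assert parse("0.10.2") <= parse(version("vllm")) <= parse("0.12.0"), "vllm outside supported range"
print("ray", version("ray"), "| vllm", version("vllm"))
```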
2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile
@@ -6,7 +6,7 @@
# docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest


FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04
FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04

WORKDIR /workspace

2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile.megatron
@@ -6,7 +6,7 @@
# docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft-megatron:latest


FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04
FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04

WORKDIR /workspace

9 changes: 4 additions & 5 deletions scripts/docker/Dockerfile.uv
@@ -9,7 +9,7 @@
# 1. This Dockerfile uses 'uv' to create a virtual environment for better package management. If you want a simpler setup without 'uv', please refer to `scripts/docker/Dockerfile`.
# 2. Make sure to use `uv pip` to install packages within the virtual environment.

FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04
FROM nvcr.io/nvidia/cuda:13.0.2-cudnn-devel-ubuntu22.04

WORKDIR /workspace

@@ -22,15 +22,14 @@ RUN chmod 1777 /tmp && apt update && apt install -y \
&& ln -sf /usr/bin/python3 /usr/bin/python \
&& ln -sf /usr/bin/pip3 /usr/bin/pip

# For Aliyun users: update pip mirror to aliyun to speed up pip install
# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/
# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com

ENV VIRTUAL_ENV=/opt/venv

# copy the Trinity-RFT dir into the workspace
COPY . .

# For Aliyun users: update pip mirror to aliyun to speed up pip install
# ENV UV_DEFAULT_INDEX=http://mirrors.cloud.aliyuncs.com/pypi/simple/

# Install uv
RUN pip install uv && uv venv /opt/venv --python=python3.12

1 change: 1 addition & 0 deletions tests/common/vllm_test.py
@@ -496,6 +496,7 @@ def setUp(self):
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True
self.config.explorer.rollout_model.enable_log_requests = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
36 changes: 31 additions & 5 deletions trinity/common/models/model.py
@@ -3,7 +3,6 @@
import asyncio
import socket
from abc import ABC, abstractmethod
from functools import partial
from typing import Dict, List, Optional, Sequence, Tuple, Union

import httpx
@@ -96,7 +95,17 @@ def __init__(
engine_type: str = "vllm",
enable_lora: bool = False,
enable_history: bool = False,
enable_thinking: Optional[bool] = None,
):
"""Initialize the ModelWrapper.

Args:
model (InferenceModel): The inference model Ray actor.
engine_type (str): The type of the model engine. Defaults to "vllm".
enable_lora (bool): Whether to enable LoRA. Defaults to False.
enable_history (bool): Whether to enable history recording. Defaults to False.
enable_thinking (Optional[bool]): Whether to enable thinking mode. Defaults to None. Only used for Qwen3 series models.
"""
assert engine_type.startswith("vllm"), "Only vLLM model is supported for now."
self.model = model
self.api_address: str = None
@@ -105,6 +114,7 @@ def __init__(
self.logger = get_logger(__name__)
self.enable_lora = enable_lora
self.enable_history = enable_history
self.enable_thinking = enable_thinking
self.history = []
self.status = RunningStatus.RUNNING
self.workflow_state: Dict = {}
@@ -303,10 +313,17 @@ def get_openai_client(self) -> openai.OpenAI:
)
if self.enable_history:
# add a decorator to the openai client to record history
ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)

ori_create = self.openai_client.chat.completions.create

def record_chat_completions(*args, **kwargs):
response = ori_create(*args, **kwargs)
extra_body = kwargs.pop("extra_body", {})
if self.enable_thinking is not None:
if "chat_template_kwargs" not in extra_body:
extra_body["chat_template_kwargs"] = {}
extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
extra_body["return_token_ids"] = True
response = ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs)
self.history.extend(convert_api_output_to_experience(response))
return response

@@ -333,10 +350,19 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
)
if self.enable_history:
# add a decorator to the openai client to record history
ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)

ori_create = self.openai_async_client.chat.completions.create

async def record_chat_completions(*args, **kwargs):
response = await ori_create(*args, **kwargs)
kwargs.pop("logprobs", True)
extra_body = kwargs.pop("extra_body", {})
if self.enable_thinking is not None:
if "chat_template_kwargs" not in extra_body:
extra_body["chat_template_kwargs"] = {}
extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
extra_body["return_token_ids"] = True
# self.logger.info("args: %s, kwargs: %s, extra_body: %s", args, kwargs, extra_body)
response = await ori_create(*args, extra_body=extra_body, logprobs=True, **kwargs)
self.history.extend(convert_api_output_to_experience(response))
return response

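
The wrapper installed by `get_openai_client` / `get_openai_async_client` above injects `chat_template_kwargs.enable_thinking` and `return_token_ids` through `extra_body` and forces `logprobs=True` before recording the response into `self.history`. A rough sketch of the equivalent direct call a workflow would otherwise have to write by hand (base URL and model id are placeholders):

```python
from openai import OpenAI

# Placeholder address; in Trinity-RFT the client returned by get_openai_client()
# is already pointed at the rollout model's API server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",  # placeholder model id
    messages=[{"role": "user", "content": "What is 1 + 1?"}],
    logprobs=True,
    extra_body={
        # Only meaningful for chat templates that understand it (e.g. Qwen3).
        "chat_template_kwargs": {"enable_thinking": False},
        # Non-standard field returned by the patched/newer vLLM server so that
        # token ids can be converted into Experience objects.
        "return_token_ids": True,
    },
)
print(response.choices[0].message.content)
```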
86 changes: 51 additions & 35 deletions trinity/common/models/vllm_model.py
@@ -7,12 +7,9 @@
import numpy as np
import ray
import torch
import vllm
from packaging.version import parse as parse_version
from PIL import Image
from transformers import AutoProcessor
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
@@ -22,7 +19,7 @@
)
from trinity.common.models.model import InferenceModel
from trinity.common.models.utils import get_action_mask_method
from trinity.common.models.vllm_patch.api_patch import get_vllm_version
from trinity.common.models.vllm_patch import get_vllm_version
from trinity.utils.log import get_logger


@@ -38,20 +35,24 @@ def __init__(
self,
config: InferenceModelConfig,
) -> None:
import vllm
from vllm.sampling_params import RequestOutputKind

self.logger = get_logger(__name__)
self.vllm_version = get_vllm_version()
self.config = config
self.use_v1 = config.use_v1
if config.tensor_parallel_size != 1:
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = config.bundle_indices
if not vllm.envs.is_set("VLLM_USE_V1"):
if self.vllm_version <= parse_version("0.11.0") and not vllm.envs.is_set("VLLM_USE_V1"):
self.logger.info(f"Using vLLM v{int(config.use_v1)} engine")
os.environ["VLLM_USE_V1"] = str(int(config.use_v1))
if config.use_v1:
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.use_v1))
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
if get_vllm_version() >= parse_version("0.11.0"):
if self.vllm_version >= parse_version("0.11.0"):
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
if not config.enforce_eager:
# To avoid torch compile conflicts when multiple models are started simultaneously.
@@ -99,7 +100,6 @@ def __init__(
trust_remote_code=True,
task="generate",
gpu_memory_utilization=config.gpu_memory_utilization,
enable_chunked_prefill=config.enable_chunked_prefill,
# max_num_batched_tokens=256, # you can further set this parameter to reduce the vllm peak memory usage
override_generation_config={ # TODO: find a way to unittest this
"temperature": config.temperature,
@@ -114,12 +114,13 @@
**rope_kwargs,
**config.lora_kwargs,
)
if get_vllm_version() > parse_version("0.10.0"):
if self.vllm_version > parse_version("0.10.0"):
engine_args.enable_log_requests = config.enable_log_requests
else:
engine_args.disable_log_requests = not config.enable_log_requests
if get_vllm_version() >= parse_version("0.11.0"):
if self.vllm_version >= parse_version("0.11.0"):
engine_args.reasoning_parser = config.reasoning_parser

self.async_llm = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.processor = None
self.tokenizer = None
@@ -157,9 +158,7 @@ async def prepare(
await self.run_api_server()
self._prepared = True

async def chat(
self, messages: List[Dict], lora_request: LoRARequest = None, **kwargs
) -> Sequence[Experience]:
async def chat(self, messages: List[Dict], lora_request=None, **kwargs) -> Sequence[Experience]:
"""Chat with the model with a list of messages in async.

Args:
@@ -190,9 +189,7 @@ async def chat(
)
return await self.generate(prompt=prompt, lora_request=lora_request, **kwargs)

async def generate(
self, prompt: str, lora_request: LoRARequest = None, **kwargs
) -> Sequence[Experience]:
async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]:
"""Generate a response from the provided prompt in async.

Args:
@@ -361,7 +358,7 @@ async def generate_mm(
async def logprobs( # type: ignore [override]
self,
token_ids: List[int],
lora_request: LoRARequest = None,
lora_request=None,
temperature: Optional[float] = None,
) -> torch.Tensor:
"""Calculate the logprobs of the given tokens in async. Please slice the result carefully
@@ -392,9 +389,7 @@ async def logprobs( # type: ignore [override]
dtype=torch.float32,
)

async def _generate_internal(
self, prompt: Any, lora_request: LoRARequest = None, **kwargs
) -> Any:
async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
# Send the request to the LLM engine.
self.request_id += 1
stream = self.async_llm.generate(
@@ -561,23 +556,42 @@ async def run_api_server(self) -> bool:
self.logger.info("OpenAI API server is already running. Skipping...")
return True # already running

from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

api_server_host, api_server_port = self.get_available_address()
self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path, # type: ignore [arg-type]
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
self.config.enable_log_requests,
if self.vllm_version <= parse_version("0.11.0"):
from trinity.common.models.vllm_patch.api_patch import (
run_api_server_in_ray_actor,
)

self.api_server = asyncio.create_task(
run_api_server_in_ray_actor(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path, # type: ignore [arg-type]
self.config.enable_auto_tool_choice,
self.config.tool_call_parser,
self.config.reasoning_parser,
self.config.enable_log_requests,
)
)
else:
from trinity.common.models.vllm_patch.api_patch_v12 import (
run_api_server_in_ray_actor_v12,
)

self.api_server = asyncio.create_task(
run_api_server_in_ray_actor_v12(
self.async_llm,
api_server_host,
api_server_port,
self.config.model_path, # type: ignore [arg-type]
logger=self.logger,
enable_auto_tool_choice=self.config.enable_auto_tool_choice,
tool_call_parser=self.config.tool_call_parser,
reasoning_parser=self.config.reasoning_parser,
enable_log_requests=self.config.enable_log_requests,
)
)
)
self.api_server_host = api_server_host
self.api_server_port = api_server_port
return True
@@ -604,7 +618,9 @@ def get_model_version(self) -> int:
def get_model_path(self) -> str:
return self.config.model_path # type: ignore [return-value]

def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest:
def get_lora_request(self, lora_path: Optional[str] = None) -> Any:
from vllm.lora.request import LoRARequest

assert self.config.lora_modules is not None
lora_request = LoRARequest(**self.config.lora_modules[0])
if lora_path is not None:
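
This file now imports vllm lazily (inside `__init__` and `get_lora_request`) and keys several behaviors off the running version via `get_vllm_version()`. A condensed, illustrative sketch of that dispatch, assuming the helper simply parses `vllm.__version__` (the real implementation lives in `trinity.common.models.vllm_patch`):

```python
from packaging.version import parse as parse_version


def get_vllm_version():
    # Lazy import so that merely importing this module does not require vLLM.
    import vllm

    return parse_version(vllm.__version__)


def pick_api_server_entrypoint() -> str:
    # Mirrors the run_api_server branch above: the legacy patch covers
    # versions up to 0.11.0, the v12 entry point everything newer.
    if get_vllm_version() <= parse_version("0.11.0"):
        return "trinity.common.models.vllm_patch.api_patch:run_api_server_in_ray_actor"
    return "trinity.common.models.vllm_patch.api_patch_v12:run_api_server_in_ray_actor_v12"
```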
9 changes: 6 additions & 3 deletions trinity/common/models/vllm_patch/api_patch.py
@@ -1,4 +1,4 @@
"""Patch for vllm OpenAI API server.
"""Patch for vllm OpenAI API server. Only for vllm versions >=0.8.5, <=0.11.0.

1. Mocks the `add_signal_handler` method to do nothing.
2. Adds `token_ids` and `prompt_token_ids` to the `ChatCompletionResponse`.
Expand Down Expand Up @@ -51,7 +51,6 @@ class PatchedChatCompletionResponse(ChatCompletionResponse):
choices: list[PatchedChatCompletionResponseChoice] = list[ChatCompletionResponseChoice]


# TODO: add patch to stream generator
async def chat_completion_full_generator( # noqa C901
self,
request,
@@ -304,7 +303,11 @@ async def patch_and_serve_http(app, sock, args):
loop = asyncio.get_event_loop()
original_add_signal_handler = loop.add_signal_handler
loop.add_signal_handler = functools.partial(dummy_add_signal_handler, loop)
OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator
vllm_version = get_vllm_version()

# From vLLM 0.10.2 on, ChatCompletionResponseChoice already includes token_ids, so no patch is needed
if vllm_version < parse_version("0.10.2"):
OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator

try:
shutdown_task = await serve_http(
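
With the guard added above, the `chat_completion_full_generator` override is applied only on vLLM releases that predate the built-in `token_ids` field. A condensed, hypothetical helper expressing that guard (the function name is illustrative; the patch module applies the assignment inline):

```python
from packaging.version import parse as parse_version


def maybe_patch_serving_chat(serving_chat_cls, patched_generator, vllm_version):
    # vLLM >= 0.10.2 already returns token_ids in ChatCompletionResponseChoice,
    # so the override is limited to older supported releases (>= 0.8.5, < 0.10.2).
    if vllm_version < parse_version("0.10.2"):
        serving_chat_cls.chat_completion_full_generator = patched_generator
```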