
Commit 8be19eb

Support vLLM v0.12.0 (#438)
1 parent a171560 commit 8be19eb

12 files changed: 469 additions & 87 deletions


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ classifiers = [
 requires-python = ">=3.10,<3.13"
 dependencies = [
     "verl==0.5.0",
-    "ray[default]>=2.48.0",
+    "ray[default]>=2.50.0",
     "vllm>=0.10.2,<=0.11.0",
     "tensordict",
     "wandb",

scripts/docker/Dockerfile.uv

Lines changed: 4 additions & 5 deletions
@@ -22,15 +22,14 @@ RUN chmod 1777 /tmp && apt update && apt install -y \
     && ln -sf /usr/bin/python3 /usr/bin/python \
     && ln -sf /usr/bin/pip3 /usr/bin/pip

-# For Aliyun users: update pip mirror to aliyun to speed up pip install
-# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/
-# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com
-
 ENV VIRTUAL_ENV=/opt/venv

 # copy the Trinity-RFT dir into the workspace
 COPY . .

+# For Aliyun users: update pip mirror to aliyun to speed up pip install
+# ENV UV_DEFAULT_INDEX=http://mirrors.cloud.aliyuncs.com/pypi/simple/
+
 # Install uv
 RUN pip install uv && uv venv /opt/venv --python=python3.12

@@ -40,7 +39,7 @@ RUN . /opt/venv/bin/activate && \

 # Install flash_attn and Megatron
 RUN . /opt/venv/bin/activate && \
-    uv pip install flash_attn==2.8.1 --no-deps --no-cache-dir && \
+    uv pip install flash_attn==2.8.1 --no-cache-dir && \
     uv pip install -e .[megatron] && \
     NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \
     uv pip install -v --no-build-isolation \

tests/common/vllm_test.py

Lines changed: 39 additions & 22 deletions
@@ -442,7 +442,7 @@ async def test_api(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
-        self.assertTrue(len(response.choices[0].token_ids) > 0)
+        self.assertTrue(response.choices[0].token_ids is None)
         with self.assertRaises(ValueError):
             self.model_wrapper_no_history.extract_experience_from_history()
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
@@ -496,6 +496,7 @@ def setUp(self):
         self.config.explorer.rollout_model.tensor_parallel_size = 1
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
+        self.config.explorer.rollout_model.enable_log_requests = True

         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -540,17 +541,17 @@ async def test_logprobs_api(self):
         logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
         self.assertEqual(logprobs_1.shape, logprobs_2.shape)
         self.assertEqual(logprobs_3.shape, logprobs_4.shape)
-        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3))
         logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
         logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
         logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
         logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
         self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
-        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
         logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
         logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
         logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
@@ -559,10 +560,18 @@ async def test_logprobs_api(self):
         self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
         self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
         self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
-        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
-        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
-        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
-        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+        self.assertTrue(
+            torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertTrue(
+            torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)
+        )

         # test vllm engine logprobs with different temperature
         response_1 = self.model_wrapper.chat(
@@ -581,17 +590,17 @@ async def test_logprobs_api(self):
         logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
         self.assertEqual(logprobs_1.shape, logprobs_2.shape)
         self.assertEqual(logprobs_3.shape, logprobs_4.shape)
-        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3))
         logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
         logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
         logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
         logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
         self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
-        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
         logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
         logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
         logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
@@ -600,10 +609,18 @@ async def test_logprobs_api(self):
         self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
         self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
         self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
-        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
-        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
-        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
-        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+        self.assertTrue(
+            torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertTrue(
+            torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)
+        )

         # test openai api and vllm engine logprobs consistency
         await self.model_wrapper.clean_workflow_state()
@@ -747,7 +764,7 @@ async def test_api_async(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
-        self.assertTrue(len(response.choices[0].token_ids) > 0)
+        self.assertTrue(response.choices[0].token_ids is None)
         with self.assertRaises(ValueError):
             self.model_wrapper_no_history.extract_experience_from_history()
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
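
Note on the flipped token_ids assertions: with the newer vLLM OpenAI-compatible server, a choice still exposes a token_ids attribute, but it stays None unless the request asks for token IDs explicitly; the ModelWrapper change later in this commit opts back in via extra_body["return_token_ids"]. The tolerance changes pair rtol=0.3 with an explicit atol=1e-3; torch.allclose treats values as close when |a - b| <= atol + rtol * |b|, so the small atol keeps near-zero logprobs from tripping a purely relative check. Below is a minimal sketch, not part of this commit, of requesting token IDs directly; the base_url and model name are placeholders.

# Sketch only: asking a vLLM OpenAI-compatible server for token IDs explicitly.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hi"}],
    logprobs=True,
    # Without this vLLM-specific field, choices[0].token_ids remains None,
    # which is exactly what the updated assertions above check.
    extra_body={"return_token_ids": True},
)
print(response.choices[0].token_ids)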

trinity/cli/launcher.py

Lines changed: 1 addition & 0 deletions
@@ -249,6 +249,7 @@ def debug(
     os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir
     load_plugins()
     config = load_config(config_path)
+    config.mode = "explore"
     config.check_and_update()
     sys.path.insert(0, os.getcwd())
     config.ray_namespace = DEBUG_NAMESPACE

trinity/common/config.py

Lines changed: 1 addition & 1 deletion
@@ -484,7 +484,7 @@ class InferenceModelConfig:
     use_v1: bool = True
     enforce_eager: bool = False
     enable_prefix_caching: bool = False
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: bool = True
     gpu_memory_utilization: float = 0.9
     dtype: str = "bfloat16"
     seed: int = 42
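
With enable_chunked_prefill now defaulting to True, rollout engines use chunked prefill unless a config opts out. A minimal sketch of an explicit opt-out, not part of this commit, assuming InferenceModelConfig is a plain dataclass whose fields all carry defaults as this hunk suggests:

# Sketch only: restore the previous behaviour for a single experiment.
from trinity.common.config import InferenceModelConfig

rollout_model = InferenceModelConfig(enable_chunked_prefill=False)
assert rollout_model.enable_prefix_caching is False   # unchanged default
assert rollout_model.enable_chunked_prefill is False  # explicit opt-out of the new default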

trinity/common/models/model.py

Lines changed: 36 additions & 9 deletions
@@ -3,8 +3,7 @@
 import asyncio
 import socket
 from abc import ABC, abstractmethod
-from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import httpx
 import numpy as np
@@ -13,7 +12,6 @@
 import torch
 from PIL import Image
 from torch import Tensor
-from vllm.lora.request import LoRARequest

 from trinity.common.constants import RunningStatus
 from trinity.common.experience import Experience
@@ -96,7 +94,17 @@ def __init__(
         engine_type: str = "vllm",
         enable_lora: bool = False,
         enable_history: bool = False,
+        enable_thinking: Optional[bool] = None,
     ):
+        """Initialize the ModelWrapper.
+
+        Args:
+            model (InferenceModel): The inference model Ray actor.
+            engine_type (str): The type of the model engine. Default to "vllm".
+            enable_lora (bool): Whether to enable LoRA. Default to False.
+            enable_history (bool): Whether to enable history recording. Default to False.
+            enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models.
+        """
         assert engine_type.startswith("vllm"), "Only vLLM model is supported for now."
         self.model = model
         self.api_address: str = None
@@ -105,6 +113,7 @@ def __init__(
         self.logger = get_logger(__name__)
         self.enable_lora = enable_lora
         self.enable_history = enable_history
+        self.enable_thinking = enable_thinking
         self.history = []
         self.status = RunningStatus.RUNNING
         self.workflow_state: Dict = {}
@@ -270,13 +279,13 @@ async def model_path_async(self) -> str:
         """Get the model path."""
         return await self.model.get_model_path.remote()

-    def get_lora_request(self) -> Optional[LoRARequest]:
+    def get_lora_request(self) -> Any:
         if self.enable_lora:
             return ray.get(self.model.get_lora_request.remote())
         else:
             return None

-    async def get_lora_request_async(self) -> Optional[LoRARequest]:
+    async def get_lora_request_async(self) -> Any:
         if self.enable_lora:
             return await self.model.get_lora_request.remote()
         else:
@@ -303,10 +312,18 @@ def get_openai_client(self) -> openai.OpenAI:
         )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
+
+            ori_create = self.openai_client.chat.completions.create

             def record_chat_completions(*args, **kwargs):
-                response = ori_create(*args, **kwargs)
+                logprobs = kwargs.pop("logprobs", True)
+                extra_body = kwargs.pop("extra_body", {})
+                if self.enable_thinking is not None:
+                    if "chat_template_kwargs" not in extra_body:
+                        extra_body["chat_template_kwargs"] = {}
+                    extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
+                extra_body["return_token_ids"] = True
+                response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs)
                 self.history.extend(convert_api_output_to_experience(response))
                 return response

@@ -333,10 +350,20 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
         )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
+
+            ori_create = self.openai_async_client.chat.completions.create

             async def record_chat_completions(*args, **kwargs):
-                response = await ori_create(*args, **kwargs)
+                logprobs = kwargs.pop("logprobs", True)
+                extra_body = kwargs.pop("extra_body", {})
+                if self.enable_thinking is not None:
+                    if "chat_template_kwargs" not in extra_body:
+                        extra_body["chat_template_kwargs"] = {}
+                    extra_body["chat_template_kwargs"]["enable_thinking"] = self.enable_thinking
+                extra_body["return_token_ids"] = True
+                response = await ori_create(
+                    *args, extra_body=extra_body, logprobs=logprobs, **kwargs
+                )
                 self.history.extend(convert_api_output_to_experience(response))
                 return response
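
Taken together, the get_openai_client and get_openai_async_client hunks replace the functools.partial wrapper with one that pops logprobs and extra_body from the caller's kwargs, injects chat_template_kwargs.enable_thinking when the wrapper was built with enable_thinking set, and always requests return_token_ids so recorded experiences keep token-level data. A usage sketch follows; it is hypothetical workflow code, not part of this commit, and the Ray actor handle and model id are placeholders.

# Usage sketch: `model` stands for an InferenceModel Ray actor handle.
from trinity.common.models.model import ModelWrapper

wrapper = ModelWrapper(
    model,                  # Ray actor running the vLLM engine
    engine_type="vllm",
    enable_history=True,    # wrap chat.completions.create to record history
    enable_thinking=False,  # forwarded as chat_template_kwargs={"enable_thinking": False}
)
client = wrapper.get_openai_client()
client.chat.completions.create(
    model="Qwen/Qwen3-8B",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
    # no extra_body needed here: the wrapper injects
    # {"chat_template_kwargs": {"enable_thinking": False}, "return_token_ids": True}
    # and logprobs=True before forwarding the call
)
experiences = wrapper.extract_experience_from_history()  # collect recorded Experience objects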

0 commit comments
