Skip to content

Commit 02c7c8e

Browse files
authored
Support vLLM v0.16.0 (#510)
1 parent 8a1c1c0 commit 02c7c8e

File tree

8 files changed

+143
-9
lines changed

8 files changed

+143
-9
lines changed
4.19 KB
Binary file not shown.

docs/sphinx_doc/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def get_recent_tags(n: int) -> list:
8585

8686
html_logo = "../_static/logo.svg"
8787

88+
html_favicon = "../_static/favicon.ico"
89+
8890
html_theme_options = {
8991
"navigation_depth": 3,
9092
"article_header_end": "article_header_customized.html",

docs/sphinx_doc/source/tutorial/trinity_installation.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ uv sync --extra vllm --extra dev --extra flash_attn
8282
# uv sync --extra tinker --extra dev
8383
```
8484

85+
```{tip}
86+
If you can't install flash-attn due to a network error or a compiler error, you can try to install it from our pre-compiled wheel:
87+
88+
`python scripts/install/install_flash_attn.py`
89+
90+
If you are using `uv`, add `--uv` flag to the command above.
91+
```
92+
8593
---
8694

8795
## Using Docker

docs/sphinx_doc/source_zh/tutorial/trinity_installation.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ uv sync --extra vllm --extra dev --extra flash_attn
8282
# uv sync --extra tinker --extra dev
8383
```
8484

85+
```{tip}
86+
如果安装 flash-attn 时遇到网络错误或编译错误,您可以尝试从我们预编译的 wheel 安装:
87+
88+
`python scripts/install/install_flash_attn.py`
89+
90+
如果您使用 `uv`,请在上述命令后添加 `--uv` 参数。
91+
```
92+
8593
---
8694

8795
## 使用 Docker

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ trinity = "trinity.cli.launcher:main"
5252

5353
[project.optional-dependencies]
5454
vllm = [
55-
"vllm>=0.10.2,<=0.15.1,!=0.11.0,!=0.12.0",
55+
"vllm>=0.10.2,<=0.16.0,!=0.11.0,!=0.12.0",
5656
# v0.11 has bug when prefix-caching is enabled so we exclude it
5757
# v0.12 has a huge performance regression so we exclude it
58-
# v0.10.2 is the most stable version, but we allow up to 0.15.1 for new features
58+
# v0.10.2 is the most stable version, but we allow up to 0.16.0 for new features
5959
]
6060
data = [
6161
"py-data-juicer>=1.4.3"
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""This script is used to install flash-attn from a pre-built wheel hosted on an OSS bucket.
2+
Useful for mainland China users who have difficulty installing flash-attn from PyPI due to network issues.
3+
"""
4+
import os
5+
import platform
6+
import subprocess
7+
import sys
8+
import tempfile
9+
10+
import torch
11+
import typer
12+
13+
app = typer.Typer()
# Version of the pre-built flash-attn wheel hosted on the OSS bucket;
# must match a wheel actually uploaded under .../flash-attn/<version>/.
FLASH_VERSION = "2.8.1"
15+
16+
17+
def check_flash_attn_installed() -> bool:
    """Return True if flash_attn is importable (printing its version), else False."""
    try:
        import flash_attn
    except ImportError:
        return False
    print(f"flash_attn version: {flash_attn.__version__}")
    return True
25+
26+
27+
def install_flash_attn(uv: bool = False, keep_wheel: bool = False):
    """Download and install a pre-built flash-attn wheel from the OSS bucket.

    The wheel filename is derived from the running torch version, python
    version, platform tag, and torch's C++ ABI flag, matching the upstream
    flash-attn wheel naming scheme.

    Args:
        uv: install with ``uv pip install`` instead of ``python -m pip install``.
        keep_wheel: keep the downloaded wheel file in the current directory
            instead of a temporary directory that is deleted afterwards.

    Exits with status 1 on unsupported builds (ROCm, non-CUDA-12 torch) or
    when flash_attn is still not importable after installation.
    """
    # Stdlib download avoids a hard dependency on an external `wget` binary,
    # which is absent on Windows and many minimal container images.
    import urllib.request

    # Only major.minor of torch is encoded in the wheel filename.
    torch_major, torch_minor = torch.__version__.split(".")[:2]
    torch_version = f"{torch_major}.{torch_minor}"

    # CPython tag, e.g. cp310.
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"

    # Platform tag, e.g. linux_x86_64.
    platform_name = platform.system().lower() + "_" + platform.machine()

    # Upstream wheel names encode the C++11 ABI flag as TRUE/FALSE.
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # torch.version.hip/cuda are runtime attributes not in type stubs.
    is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None  # type: ignore[attr-defined]

    if is_rocm:
        print("We currently do not host ROCm wheels for flash-attn.")
        sys.exit(1)
    else:
        torch_cuda_version = torch.version.cuda  # type: ignore[attr-defined]
        cuda_major = torch_cuda_version.split(".")[0] if torch_cuda_version else None
        if cuda_major != "12":
            print("Only CUDA 12 wheels are hosted for flash-attn.")
            sys.exit(1)
        cuda_version = "12"

    # %2B is the URL-encoded '+' separating the local version segment.
    wheel_filename = (
        f"flash_attn-{FLASH_VERSION}%2Bcu{cuda_version}torch{torch_version}"
        f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    )
    local_filename = (
        f"flash_attn-{FLASH_VERSION}-{python_version}-{python_version}-{platform_name}.whl"
    )

    wheel_url = (
        "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com"
        f"/AgentScope/download/flash-attn/{FLASH_VERSION}/{wheel_filename}"
    )

    print(f"wheel_url: {wheel_url}")
    print(f"target_local_file: {local_filename}")

    def _install_helper(local_path: str):
        # Download the wheel, then install it with the chosen installer.
        urllib.request.urlretrieve(wheel_url, local_path)
        install_cmd = (
            ["uv", "pip", "install", local_path]
            if uv
            else [sys.executable, "-m", "pip", "install", local_path]
        )
        subprocess.run(install_cmd, check=True)

    if keep_wheel:
        local_path = os.path.abspath(local_filename)
        _install_helper(local_path)
    else:
        with tempfile.TemporaryDirectory() as tempdir:
            local_path = os.path.join(tempdir, local_filename)
            _install_helper(local_path)

    # Verify the installation by actually importing flash_attn.
    if not check_flash_attn_installed():
        print("Failed to install flash_attn.")
        sys.exit(1)
93+
94+
95+
@app.command()
def main(
    uv: bool = typer.Option(False, help="Use uv pip to install instead of pip"),
    keep_wheel: bool = typer.Option(
        False, help="Keep the downloaded wheel file in current directory"
    ),
):
    """Install flash-attn from a pre-built wheel."""
    # Skip the download entirely when flash_attn is already importable.
    if not check_flash_attn_installed():
        install_flash_attn(uv=uv, keep_wheel=keep_wheel)
    else:
        print("flash_attn is already installed. Skipping installation.")
107+
108+
109+
if __name__ == "__main__":
    # Run the module-level Typer app; `main` is already registered on it via
    # @app.command(), so `typer.run(main)` would build a redundant second app.
    app()

tests/common/vllm_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ async def test_api_tool_calls(self):
10331033
print_debug(f" > Finish Reason: {choice.finish_reason}")
10341034
self.assertEqual(choice.finish_reason, "tool_calls")
10351035
if self.enable_thinking:
1036-
self.assertIsNotNone(choice.message.reasoning_content)
1036+
self.assertIsNotNone(choice.message.reasoning)
10371037
self.assertIsNotNone(choice.message.tool_calls)
10381038
self.assertEqual(len(choice.message.tool_calls), 1)
10391039

trinity/common/models/vllm_patch/worker_patch.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901
1414
"""Patch vLLM model runner to support prompt logprobs extraction."""
1515
version = get_vllm_version()
16-
if version < parse_version("0.10.2") or version > parse_version("0.15.1"):
16+
if version < parse_version("0.10.2") or version > parse_version("0.16.0"):
1717
raise ValueError(
1818
f"Unsupported vllm version: {vllm.__version__}. "
19-
"This patch requires vllm version >= 0.10.2, <= 0.15.1."
19+
"This patch requires vllm version >= 0.10.2, <= 0.16.0."
2020
)
2121
is_v0102 = version == parse_version("0.10.2")
2222

@@ -237,15 +237,21 @@ def _get_prompt_logprobs_dict_v12(
237237

238238
# Compute prompt logprobs.
239239
logprobs = self.sampler.compute_logprobs(logits)
240-
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
240+
logprob_tensors = self.sampler.gather_logprobs(
241241
logprobs, num_prompt_logprobs, tgt_token_ids
242242
)
243243

244244
# Transfer GPU->CPU async.
245245
chunk_slice = slice(start_idx, start_idx + num_logits)
246-
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, non_blocking=True)
247-
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
248-
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, non_blocking=True)
246+
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
247+
logprob_tensors.logprob_token_ids, non_blocking=True
248+
)
249+
logprobs_tensors.logprobs[chunk_slice].copy_(
250+
logprob_tensors.logprobs, non_blocking=True
251+
)
252+
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
253+
logprob_tensors.selected_token_ranks, non_blocking=True
254+
)
249255

250256
# Remove requests that have completed prefill from the batch
251257
# num_prompt_logprobs_dict.

0 commit comments

Comments
 (0)