Commit 8f635b6

Update vLLM version support to include 0.14.0 and 0.14.1 (#5214)
1 parent d24ec77 commit 8f635b6

File tree

6 files changed: +75 −33 lines


docs/source/vllm_integration.md

Lines changed: 13 additions & 27 deletions

@@ -3,7 +3,7 @@
 This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.

 > [!WARNING]
-> TRL currently only supports vLLM versions from `0.10.2` to `0.13.0`. Please ensure you have a version in this range installed to avoid compatibility issues.
+> TRL currently only supports vLLM versions from `0.10.2` to `0.14.1`. Please ensure you have a version in this range installed to avoid compatibility issues.

 > [!TIP]
 > The following trainers currently support generation with vLLM:
@@ -31,12 +31,12 @@ pip install "trl[vllm]"
 Then run the server on specific GPUs (e.g., GPUs 0-3):

 ```sh
-CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
+CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 4
 ```

 Once the server is running, you can use it to generate completions for training. In the example below, we are using the different supported trainers using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.

-In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows:
+In this example, we shard one model across 4 GPUs with tensor parallelism. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows:

 Sample of a simple `train.py` script:
@@ -166,19 +166,15 @@ If you've ever done autoregressive decoder training, you know all the input toke
 When you run for example

 ```sh
-CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
+CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 4
 ```

-the following happens:
+1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (4 × 1).
+   Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation.

-![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png)
+2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled.

-1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
-   Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
-
-2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
-
-3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
+3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--tensor-parallel-size=4`, each GPU participates in serving the full model).
    This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
    Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
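The worker and completion arithmetic described in the steps above can be sketched in a few lines of Python (an illustration of the counts only, not TRL code; `serving_counts` is a made-up name):

```python
def serving_counts(tensor_parallel_size: int, data_parallel_size: int,
                   num_prompts: int, num_generations: int) -> dict:
    """Illustrative helper (not part of TRL): worker and completion counts."""
    # vLLM spawns tp * dp workers in total
    workers = tensor_parallel_size * data_parallel_size
    # Prompts are sharded across data-parallel replicas only
    prompts_per_replica = num_prompts // data_parallel_size
    # Completion count depends only on prompts and num_generations
    total_completions = num_prompts * num_generations
    return {
        "workers": workers,
        "prompts_per_replica": prompts_per_replica,
        "total_completions": total_completions,
    }


# tp=4, dp=1 with 8 prompts and num_generations=2, as in the example above
print(serving_counts(4, 1, 8, 2))
# → {'workers': 4, 'prompts_per_replica': 8, 'total_completions': 16}
```

With the old `tp=1, dp=4` setup, the same 8 prompts would instead be split 2 per replica, which matches the deleted text in this hunk.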

@@ -224,7 +220,9 @@ options:
   --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
                         Number of tensor parallel workers to use. (default: 1)
   --data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
-                        Number of data parallel workers to use. (default: 1)
+                        Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM `0.14.0`, setting
+                        this above `1` for dense models is no longer supported/useful and will error out (see vLLM PR #30739).
+                        (default: 1)
   --host HOST           Host address to run the server on. (default: 0.0.0.0)
   --port PORT           Port to run the server on. (default: 8000)
   --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
@@ -259,20 +257,8 @@ options:
 ![tp dp throughput 8 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png)
 ![tp dp throughput 4 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png)

-First and foremost, always remember that the optimal setup depends on:
-
-- The model size
-- The number of GPUs you have
-- The GPU memory size
-- The batch size you are using
-- The number of requests you are sending to the server (prompts)
-- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
-- The number of completions you are generating for each request (`num_generations`)
-
-Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
-
-- For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
-- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
+> [!WARNING]
+> The benchmark plots above were collected with older vLLM versions. Starting with [vLLM PR #30739](https://github.com/vllm-project/vllm/pull/30739) (released in `0.14.0`), offline data parallel scaling for non-MoE (dense) models is no longer supported. To follow the latest recommendations, do not scale DP for non-MoE models.

 ### vLLM with Transformers Backend

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ test = [
     "pytest"
 ]
 vllm = [
-    "vllm>=0.10.2,<0.14.0",
+    "vllm>=0.10.2,<=0.14.1",
     "fastapi",
     "pydantic",
     "requests",

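The effect of moving the pin from `<0.14.0` to `<=0.14.1` can be verified with the `packaging` library (a quick sketch, not part of the commit):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_spec = SpecifierSet(">=0.10.2,<0.14.0")
new_spec = SpecifierSet(">=0.10.2,<=0.14.1")

# 0.14.0 and 0.14.1 were excluded by the old pin and are allowed by the new one
for v in ("0.14.0", "0.14.1"):
    print(v, Version(v) in old_spec, Version(v) in new_spec)
# 0.14.0 False True
# 0.14.1 False True
```

Note that the upper bound is now inclusive (`<=`), so a future `0.14.2` would still be rejected until the pin is bumped again.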
tests/test_vllm_client_server.py

Lines changed: 50 additions & 0 deletions

@@ -17,6 +17,7 @@
 from types import SimpleNamespace

 import pytest
+from packaging.version import Version
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.testing_utils import torch_device

@@ -35,8 +36,13 @@
 if is_vllm_available():
+    import vllm
     from vllm import LLM, SamplingParams

+    _is_vllm_ge_014 = Version(vllm.__version__) >= Version("0.14.0")
+else:
+    _is_vllm_ge_014 = False


 class TestChunkList(TrlTestCase):
     def test_even_split(self):

@@ -530,6 +536,26 @@ def multiply(a: int, b: int) -> int:
         decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
         assert "Multiplies two integers." in decoded_prompt

+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+            "completion_ids"
+        ]
+
+        # Check that the output is a list
+        assert isinstance(completion_ids, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        assert len(completion_ids) == 2 * len(prompts)
+
+        # Check that the generated sequences are lists of integers
+        for seq in completion_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in completion_ids:
+            assert len(seq) <= 32
+
     def test_update_model_params(self):
         model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
         self.client.update_model_params(model)

@@ -549,6 +575,10 @@ def teardown_class(cls):
 @pytest.mark.slow
+@pytest.mark.skipif(
+    _is_vllm_ge_014,
+    reason="Skipping DP server test for vLLM>=0.14.0 (PR vllm#30739: DP for non-MoE/dense models no longer supported).",
+)
 @require_3_accelerators
 @require_vllm
 class TestVLLMClientServerDP(TrlTestCase):

@@ -635,6 +665,26 @@ def multiply(a: int, b: int) -> int:
         decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
         assert "Multiplies two integers." in decoded_prompt

+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+            "completion_ids"
+        ]
+
+        # Check that the output is a list
+        assert isinstance(completion_ids, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        assert len(completion_ids) == 2 * len(prompts)
+
+        # Check that the generated sequences are lists of integers
+        for seq in completion_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in completion_ids:
+            assert len(seq) <= 32
+
     def test_update_model_params(self):
         model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
         self.client.update_model_params(model)
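The shape checks added in `test_generate_with_params` can be factored into a standalone validator and exercised on fake data (an illustrative sketch; `validate_completions` and the fake output are not part of the commit):

```python
def validate_completions(completion_ids, num_prompts: int, n: int, max_tokens: int) -> bool:
    """Mirror the assertions from test_generate_with_params on arbitrary output."""
    # The server returns a flat list of token-id sequences
    assert isinstance(completion_ids, list)
    # n completions per prompt
    assert len(completion_ids) == n * num_prompts
    for seq in completion_ids:
        # Token ids are plain ints, and each sequence is capped at max_tokens
        assert all(isinstance(tok, int) for tok in seq)
        assert len(seq) <= max_tokens
    return True


# Fake server output: 2 prompts with n=2 → 4 token-id sequences
fake_output = [[101, 7592], [101], [2129, 2024, 2017], [102]]
print(validate_completions(fake_output, num_prompts=2, n=2, max_tokens=32))  # → True
```

Factoring the checks this way would also let both the TP and DP test classes share one helper instead of duplicating the assertion block.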

trl/_compat.py

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def _patch_vllm_disabled_tqdm() -> None:

     - Bug introduced in https://github.com/vllm-project/vllm/pull/52
     - Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1)
-    - Since TRL currently supports vLLM v0.10.2-0.13.0, we patch it here
+    - Since TRL currently supports vLLM v0.10.2-0.14.1, we patch it here
     - This can be removed when TRL requires vLLM>=0.11.1
     """
     if _is_package_version_below("vllm", "0.11.1"):
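A helper like `_is_package_version_below` can be sketched with `importlib.metadata` and `packaging` (a plausible shape only, assuming these behaviors; TRL's actual implementation may differ):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def is_package_version_below(package: str, threshold: str) -> bool:
    """Return True if `package` is installed with a version strictly below `threshold`."""
    try:
        return Version(version(package)) < Version(threshold)
    except PackageNotFoundError:
        # Treat a missing package as "not below": there is nothing to patch
        return False
```

Under this shape, the tqdm patch above runs only for installed vLLM versions below `0.11.1`, where the upstream fix is absent.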

trl/import_utils.py

Lines changed: 2 additions & 2 deletions

@@ -109,9 +109,9 @@ def is_uvicorn_available() -> bool:
 def is_vllm_available() -> bool:
     _vllm_available, _vllm_version = _is_package_available("vllm", return_version=True)
     if _vllm_available:
-        if not (Version("0.10.2") <= Version(_vllm_version) <= Version("0.13.0")):
+        if not (Version("0.10.2") <= Version(_vllm_version) <= Version("0.14.1")):
             warnings.warn(
-                f"TRL currently supports vLLM versions from 0.10.2 to 0.13.0. You have version {_vllm_version} "
+                f"TRL currently supports vLLM versions from 0.10.2 to 0.14.1. You have version {_vllm_version} "
                 "installed. We recommend installing a supported version to avoid compatibility issues.",
                 stacklevel=2,
             )
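The inclusive range comparison used here is easy to verify in isolation (a minimal sketch of the same check; `in_supported_range` is an illustrative name, not TRL API):

```python
from packaging.version import Version


def in_supported_range(vllm_version: str, low: str = "0.10.2", high: str = "0.14.1") -> bool:
    """Inclusive bounds on both ends, matching the comparison in is_vllm_available."""
    return Version(low) <= Version(vllm_version) <= Version(high)


for v in ("0.10.1", "0.10.2", "0.13.0", "0.14.1", "0.14.2"):
    print(v, in_supported_range(v))
# 0.10.1 False
# 0.10.2 True
# 0.13.0 True
# 0.14.1 True
# 0.14.2 False
```

Using `packaging.version.Version` rather than string comparison matters here: as plain strings, `"0.9.0" > "0.14.1"`, which would make the bounds check wrong.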

trl/scripts/vllm_serve.py

Lines changed: 8 additions & 2 deletions

@@ -211,7 +211,9 @@ class ScriptArguments:
         tensor_parallel_size (`int`, *optional*, defaults to `1`):
             Number of tensor parallel workers to use.
         data_parallel_size (`int`, *optional*, defaults to `1`):
-            Number of data parallel workers to use.
+            Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM `0.14.0`,
+            setting this above `1` for dense models is no longer supported/useful and will error out (see vLLM PR
+            #30739).
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
             Host address to run the server on.
         port (`int`, *optional*, defaults to `8000`):

@@ -261,7 +263,11 @@ class ScriptArguments:
     )
     data_parallel_size: int = field(
         default=1,
-        metadata={"help": "Number of data parallel workers to use."},
+        metadata={
+            "help": "Number of data parallel workers to use. For dense models, keep this at 1. Starting from vLLM "
+            "`0.14.0`, setting this above `1` for dense models is no longer supported/useful and will error out (see "
+            "vLLM PR #30739)."
+        },
     )
     host: str = field(
         default="0.0.0.0",
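The new help text implies a runtime constraint that a server script could enforce up front rather than letting vLLM error out later. A hypothetical pre-flight guard (the names `check_data_parallel` and `is_moe` are made up for illustration, not TRL or vLLM API) might look like:

```python
from packaging.version import Version


def check_data_parallel(data_parallel_size: int, vllm_version: str, is_moe: bool = False) -> None:
    """Hypothetical pre-flight check: dense models cannot use dp > 1 on vLLM >= 0.14.0."""
    if data_parallel_size > 1 and not is_moe and Version(vllm_version) >= Version("0.14.0"):
        raise ValueError(
            "data_parallel_size > 1 is not supported for dense models on vLLM >= 0.14.0 "
            "(see vLLM PR #30739); use tensor parallelism instead."
        )


check_data_parallel(1, "0.14.1")                # fine: dp=1
check_data_parallel(4, "0.13.0")                # fine: older vLLM still supports DP
check_data_parallel(4, "0.14.1", is_moe=True)   # fine: MoE models may still scale DP
```

Failing fast with a clear message at argument-parsing time is friendlier than surfacing the error from deep inside vLLM's engine startup.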
