Merged
Changes from 52 commits (75 commits total)
f10285e
support prompts or token IDs in VLLMClient and update API request han…
qgallouedec Mar 5, 2026
7d2bb67
test
qgallouedec Mar 5, 2026
3b356ac
consistency
qgallouedec Mar 5, 2026
82c4508
fix
qgallouedec Mar 5, 2026
3ea2fcf
another fix
qgallouedec Mar 5, 2026
445f4ba
fix docstring
qgallouedec Mar 5, 2026
8c6c88d
Add support for multi-modal inputs in VLLMClient and vllm_serve
qgallouedec Mar 5, 2026
f617b2d
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
eaffd67
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
f3f6a5d
Move `rollout_func` from `_generate_single_turn` to `_generate`
qgallouedec Mar 6, 2026
d417543
fix style
qgallouedec Mar 6, 2026
4b927d6
support multi-image
qgallouedec Mar 6, 2026
029fc1f
style
qgallouedec Mar 6, 2026
20b4039
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 6, 2026
b8e3912
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 6, 2026
07181cb
Fix handling of images in OnlineDPOTrainer to ensure proper structure…
qgallouedec Mar 7, 2026
6ff1e56
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 7, 2026
9f340e4
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 7, 2026
d138be7
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 7, 2026
09128d6
Move tokenization before vLLM generation call
qgallouedec Mar 7, 2026
7fd1711
Fix deadlock issue by ensuring images are always gathered in VLLMGene…
qgallouedec Mar 7, 2026
3ab04b0
Unify tokenization across all generation backends in _generate_single…
qgallouedec Mar 7, 2026
5d6d067
Extract tokenization out of _generate_single_turn into _tokenize_prompts
qgallouedec Mar 7, 2026
b4d2c34
Enhance multimodal input handling in GRPO and RLOO trainers by adding…
qgallouedec Mar 7, 2026
4922362
style
qgallouedec Mar 7, 2026
37c48b3
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 7, 2026
0a264a2
Fix tokenization padding issue in GRPOTrainer to handle unpadded inpu…
qgallouedec Mar 7, 2026
0aa0e30
style
qgallouedec Mar 7, 2026
b490357
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 7, 2026
8fecba1
align rloo
qgallouedec Mar 7, 2026
6c093dd
style
qgallouedec Mar 7, 2026
a9a91c7
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 7, 2026
f033e63
revert doc modif
qgallouedec Mar 9, 2026
5a1f609
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 9, 2026
1eb3540
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
498a564
Merge branch 'move-rollout-func' into vllm-generate-with-token-ids
qgallouedec Mar 9, 2026
be2ff99
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 9, 2026
5df2069
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 9, 2026
d3f7971
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
319d52a
simplify multimodal
qgallouedec Mar 9, 2026
d5e1906
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
4ccadcf
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
2a80df9
Merge branch 'move-rollout-func' into vllm-generate-with-token-ids
qgallouedec Mar 9, 2026
a0df552
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 9, 2026
3350588
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 9, 2026
0558dc9
Merge branch 'main' into move-rollout-func
qgallouedec Mar 9, 2026
6ebb681
Merge branch 'move-rollout-func' into vllm-generate-with-token-ids
qgallouedec Mar 9, 2026
93640e4
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 9, 2026
1c009b0
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 9, 2026
97a813b
Merge branch 'main' into vllm-generate-with-token-ids
qgallouedec Mar 10, 2026
83ab9bd
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 10, 2026
408fb2e
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
ade2831
Merge branch 'main' into vllm-generate-with-token-ids
qgallouedec Mar 10, 2026
258e0a8
Update trl/trainer/grpo_trainer.py
qgallouedec Mar 10, 2026
ef96048
Update trl/trainer/rloo_trainer.py
qgallouedec Mar 10, 2026
0ee6495
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 10, 2026
bb6dc69
Update trl/trainer/grpo_trainer.py
qgallouedec Mar 10, 2026
0effa0d
Update trl/trainer/rloo_trainer.py
qgallouedec Mar 10, 2026
fad1fdd
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
b35f250
Remove unused chat/tool configuration parameters from VLLM and RLOO t…
qgallouedec Mar 10, 2026
040e392
Update trl/generation/vllm_generation.py
qgallouedec Mar 10, 2026
ca2cae3
Update trl/trainer/rloo_trainer.py
qgallouedec Mar 10, 2026
fee553d
Merge branch 'main' into vllm-generate-with-token-ids
qgallouedec Mar 10, 2026
90df2de
Merge branch 'vllm-generate-with-token-ids' into unify-tokenization-g…
qgallouedec Mar 10, 2026
f36c0ea
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
fdaa90a
fix
qgallouedec Mar 10, 2026
6f10cd2
style
qgallouedec Mar 10, 2026
533c337
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
7e7e3b3
Merge branch 'main' into unify-tokenization-generate
qgallouedec Mar 10, 2026
31d8a0c
Merge branch 'unify-tokenization-generate' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
8b4f6af
Merge branch 'main' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
81cf273
Merge branch 'main' into extract-tokenize-prompts
qgallouedec Mar 10, 2026
918686b
Remove dead code: eliminate prompt tokenization logic from GRPOTraine…
qgallouedec Mar 10, 2026
9b8de83
remove unused extra_fields from _generate_single_turn return value
qgallouedec Mar 10, 2026
6c8f55c
style
qgallouedec Mar 10, 2026
102 changes: 51 additions & 51 deletions trl/generation/vllm_generation.py
@@ -14,7 +14,6 @@

"""vLLM-based generation backend for TRL trainers."""

import json
import logging
import math
import os
@@ -29,7 +28,6 @@
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, is_bitsandbytes_available
from transformers.utils import is_torch_mlu_available, is_torch_npu_available, is_torch_xpu_available

from ..data_utils import is_conversational, prepare_multimodal_messages_vllm
from ..extras.profiling import ProfilingContext
from ..import_utils import is_vllm_available
from ..trainer.utils import ensure_master_addr_port
@@ -528,13 +526,21 @@ def sync_weights(self):
elif self.mode == "colocate":
self.llm.reset_prefix_cache()

def generate(self, prompts: list, num_generations: int, profiler: ProfilingContext | None = None) -> tuple:
def generate(
self,
prompts: list[list[int]],
images: list[list] | None,
num_generations: int,
profiler: ProfilingContext | None = None,
) -> tuple:
"""Generate completions using vLLM.

Args:
prompts: List of prompts (strings or chat conversations)
num_generations: Number of generations per prompt
profiler: Optional profiler for performance tracking
prompts: List of token ID lists, one per prompt (already tokenized).
images: Optional list of image lists for VLM support. Each element is a list of PIL images for the
corresponding prompt, or `None` if that prompt has no images; the argument itself is `None` when no prompt has images.
num_generations: Number of generations per prompt.
profiler: Optional profiler for performance tracking.

Returns:
Tuple of (prompt_ids, completion_ids, logprobs, logprob_token_ids, extra_fields).
@@ -567,9 +573,6 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
min_p = self.min_p
repetition_penalty = self.repetition_penalty
max_completion_length = self.max_completion_length
chat_template_kwargs = self.chat_template_kwargs
tools = self.tools
chat_template = self.chat_template

# Wake up colocated vLLM weights if needed (idempotent if already awake from sync_weights)
if self.mode == "colocate" and self.enable_sleep_mode:
@@ -582,28 +585,21 @@
# Non-CUDA vLLM backends (e.g., vllm-ascend's NPUWorkerV1), don't implement reload_weights
pass

if is_conversational({"prompt": prompts[0]}):
prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]

# In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820
for prompt in prompts: # iterate over each conversation
if is_conversational({"prompt": prompt}):
for message in prompt: # iterate over each message
if "tool_calls" in message: # check if message has tool calls
for call in message["tool_calls"]:
args_value = call["function"]["arguments"]
if isinstance(args_value, dict): # only convert dict → JSON string
call["function"]["arguments"] = json.dumps(args_value)

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.mode == "server":
all_prompts = gather_object(prompts)
# Always gather images (even when None) to avoid deadlock: images may be None on some ranks
# and non-None on others in mixed datasets, and gather_object is a collective operation.
all_images = gather_object(images if images is not None else [None] * len(prompts))
if all(img is None for img in all_images):
all_images = None
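The deadlock-avoidance comment above can be illustrated without any distributed setup. This is a minimal sketch in which plain list concatenation stands in for `gather_object`'s cross-rank gather; the rank data and the `normalize_images` helper are hypothetical.

```python
def normalize_images(images, num_prompts):
    # Every rank must join the collective call, so a rank without images
    # contributes a same-length list of Nones instead of skipping the gather.
    return images if images is not None else [None] * num_prompts

# Two hypothetical ranks of a mixed text/multimodal batch:
rank0 = normalize_images(None, 2)             # text-only rank
rank1 = normalize_images([["img"], None], 2)  # one multimodal prompt
all_images = rank0 + rank1  # stands in for gather_object's concatenation
if all(img is None for img in all_images):    # collapse the all-text case
    all_images = None
assert all_images == [None, None, ["img"], None]
```

If one rank skipped the gather while another entered it, the collective would block forever, which is exactly the deadlock the normalization avoids.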

if accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts[::num_generations]
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and
# generate num_generations outputs for each one. This is faster than generating outputs for each
# duplicate prompt individually.
ordered_set_of_prompt_ids = all_prompts[::num_generations]
ordered_set_of_images = all_images[::num_generations] if all_images is not None else None
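As the comment above notes, each prompt appears `num_generations` times in the gathered list, so a stride slice recovers the unique prompts in order. A minimal sketch with illustrative token IDs:

```python
num_generations = 2
# Gathered prompts arrive with each prompt repeated num_generations times.
all_prompts = [[1, 2], [1, 2], [3, 4], [3, 4]]
# Striding by num_generations keeps the first copy of each unique prompt.
ordered_set_of_prompt_ids = all_prompts[::num_generations]
assert ordered_set_of_prompt_ids == [[1, 2], [3, 4]]
```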

sampling_params = {
"n": num_generations,
@@ -617,18 +613,12 @@
"structured_outputs_regex": self.structured_outputs_regex,
"generation_kwargs": self.generation_kwargs,
}
with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=chat_template_kwargs,
tools=tools,
chat_template=chat_template,
)
else:
ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"]
output = self.vllm_client.generate(prompts=ordered_set_of_prompt_ids, **sampling_params)
with profiler:
output = self.vllm_client.generate(
prompts=ordered_set_of_prompt_ids,
images=ordered_set_of_images,
**sampling_params,
)
# Extract required fields and collect any extra fields for reward functions
required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
@@ -647,7 +637,7 @@
broadcast_object_list(obj_list, from_process=0)
all_prompt_ids, all_completion_ids, all_logprobs, all_logprob_token_ids, all_extra_fields = obj_list[0]

# vllm_client.generate/chat(n=num_generations) returns num_generations completions per prompt.
# vllm_client.generate(n=num_generations) returns num_generations completions per prompt.
# Duplicate prompt_ids to align with per-completion entries.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
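This duplication step is the inverse of the dedup slice: vLLM returns `num_generations` completions per unique prompt, so each prompt's IDs are repeated until prompt row `i` lines up with completion row `i`. A sketch with illustrative IDs:

```python
num_generations = 3
all_prompt_ids = [[10, 11], [20, 21]]  # one entry per unique prompt
# Repeat each prompt's token IDs once per completion it produced.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
assert all_prompt_ids == [[10, 11]] * 3 + [[20, 21]] * 3
```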

@@ -702,24 +692,34 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
gathered_prompts = [None for _ in range(self.tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
# Always gather images (even when None) to avoid deadlock: images may be None on some
# ranks and non-None on others in mixed datasets, and all_gather_object is collective.
local_images = images if images is not None else [None] * len(prompts)
gathered_images = [None for _ in range(self.tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, local_images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
if all(img is None for img in all_images):
all_images = None
else:
all_prompts = prompts
all_images = images

if self.enable_sleep_mode:
self.llm.wake_up(tags=["kv_cache"])

with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(
all_prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template_kwargs=chat_template_kwargs,
tools=tools,
chat_template=chat_template,
)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
# Build vLLM-compatible prompt inputs with token IDs and optional multi-modal data
vllm_prompts = []
if all_images is not None:
for ids, img_list in zip(all_prompts, all_images, strict=True):
row = {"prompt_token_ids": ids}
if img_list is not None:
row["multi_modal_data"] = {"image": img_list if len(img_list) > 1 else img_list[0]}
vllm_prompts.append(row)
else:
vllm_prompts = [{"prompt_token_ids": ids} for ids in all_prompts]

with profiler:
all_outputs = self.llm.generate(vllm_prompts, sampling_params=sampling_params, use_tqdm=False)

all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]