Skip to content

Commit b77f36f

Browse files
[GRPO/RLOO] Tokenize before vLLM generation call (#5238)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
1 parent bd5307e commit b77f36f

File tree

3 files changed

+119
-69
lines changed

3 files changed

+119
-69
lines changed

trl/generation/vllm_generation.py

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""vLLM-based generation backend for TRL trainers."""
1616

17-
import json
1817
import logging
1918
import math
2019
import os
@@ -29,7 +28,6 @@
2928
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, is_bitsandbytes_available
3029
from transformers.utils import is_torch_mlu_available, is_torch_npu_available, is_torch_xpu_available
3130

32-
from ..data_utils import is_conversational, prepare_multimodal_messages_vllm
3331
from ..extras.profiling import ProfilingContext
3432
from ..import_utils import is_vllm_available
3533
from ..trainer.utils import ensure_master_addr_port
@@ -245,10 +243,6 @@ def __init__(
245243
max_completion_length: int = 16,
246244
logprobs: int | None = 0,
247245
generation_kwargs: dict | None = None,
248-
# Chat/tool configuration
249-
chat_template: str | None = None,
250-
chat_template_kwargs: dict | None = None,
251-
tools: list | None = None,
252246
):
253247
self.model = model
254248
self.accelerator = accelerator
@@ -284,11 +278,6 @@ def __init__(
284278
self.logprobs = logprobs
285279
self.generation_kwargs = generation_kwargs or {}
286280

287-
# Chat/tool configuration
288-
self.chat_template = chat_template
289-
self.chat_template_kwargs = chat_template_kwargs or {}
290-
self.tools = tools
291-
292281
self._init_vllm()
293282

294283
def _init_vllm(self):
@@ -528,13 +517,21 @@ def sync_weights(self):
528517
elif self.mode == "colocate":
529518
self.llm.reset_prefix_cache()
530519

531-
def generate(self, prompts: list, num_generations: int, profiler: ProfilingContext | None = None) -> tuple:
520+
def generate(
521+
self,
522+
prompts: list[list[int]],
523+
images: list[list | None] | None,
524+
num_generations: int,
525+
profiler: ProfilingContext | None = None,
526+
) -> tuple:
532527
"""Generate completions using vLLM.
533528
534529
Args:
535-
prompts: List of prompts (strings or chat conversations)
536-
num_generations: Number of generations per prompt
537-
profiler: Optional profiler for performance tracking
530+
prompts: List of token ID lists, one per prompt (already tokenized).
531+
images: Optional list of image lists for VLM support. Each element is a list of PIL images for the
532+
corresponding prompt, or `None` if no images for that prompt. `None` if no images at all.
533+
num_generations: Number of generations per prompt.
534+
profiler: Optional profiler for performance tracking.
538535
539536
Returns:
540537
Tuple of (prompt_ids, completion_ids, logprobs, logprob_token_ids, extra_fields).
@@ -567,9 +564,6 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
567564
min_p = self.min_p
568565
repetition_penalty = self.repetition_penalty
569566
max_completion_length = self.max_completion_length
570-
chat_template_kwargs = self.chat_template_kwargs
571-
tools = self.tools
572-
chat_template = self.chat_template
573567

574568
# Wake up colocated vLLM weights if needed (idempotent if already awake from sync_weights)
575569
if self.mode == "colocate" and self.enable_sleep_mode:
@@ -582,28 +576,21 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
582576
# Non-CUDA vLLM backends (e.g., vllm-ascend's NPUWorkerV1), don't implement reload_weights
583577
pass
584578

585-
if is_conversational({"prompt": prompts[0]}):
586-
prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]
587-
588-
# In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820
589-
for prompt in prompts: # iterate over each conversation
590-
if is_conversational({"prompt": prompt}):
591-
for message in prompt: # iterate over each message
592-
if "tool_calls" in message: # check if message has tool calls
593-
for call in message["tool_calls"]:
594-
args_value = call["function"]["arguments"]
595-
if isinstance(args_value, dict): # only convert dict → JSON string
596-
call["function"]["arguments"] = json.dumps(args_value)
597-
598579
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
599580
if self.mode == "server":
600581
all_prompts = gather_object(prompts)
582+
# Always gather images (even when None) to avoid deadlock: images may be None on some ranks
583+
# and non-None on others in mixed datasets, and gather_object is a collective operation.
584+
all_images = gather_object(images if images is not None else [None] * len(prompts))
585+
if all(img is None for img in all_images):
586+
all_images = None
601587

602588
if accelerator.is_main_process:
603-
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
604-
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
605-
# prompt individually.
606-
ordered_set_of_prompts = all_prompts[::num_generations]
589+
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and
590+
# generate num_generations outputs for each one. This is faster than generating outputs for each
591+
# duplicate prompt individually.
592+
ordered_set_of_prompt_ids = all_prompts[::num_generations]
593+
ordered_set_of_images = all_images[::num_generations] if all_images is not None else None
607594

608595
sampling_params = {
609596
"n": num_generations,
@@ -617,18 +604,12 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
617604
"structured_outputs_regex": self.structured_outputs_regex,
618605
"generation_kwargs": self.generation_kwargs,
619606
}
620-
with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
621-
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
622-
output = self.vllm_client.chat(
623-
messages=ordered_set_of_prompts,
624-
**sampling_params,
625-
chat_template_kwargs=chat_template_kwargs,
626-
tools=tools,
627-
chat_template=chat_template,
628-
)
629-
else:
630-
ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"]
631-
output = self.vllm_client.generate(prompts=ordered_set_of_prompt_ids, **sampling_params)
607+
with profiler:
608+
output = self.vllm_client.generate(
609+
prompts=ordered_set_of_prompt_ids,
610+
images=ordered_set_of_images,
611+
**sampling_params,
612+
)
632613
# Extract required fields and collect any extra fields for reward functions
633614
required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"}
634615
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
@@ -647,7 +628,7 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
647628
broadcast_object_list(obj_list, from_process=0)
648629
all_prompt_ids, all_completion_ids, all_logprobs, all_logprob_token_ids, all_extra_fields = obj_list[0]
649630

650-
# vllm_client.generate/chat(n=num_generations) returns num_generations completions per prompt.
631+
# vllm_client.generate(n=num_generations) returns num_generations completions per prompt.
651632
# Duplicate prompt_ids to align with per-completion entries.
652633
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
653634

@@ -702,24 +683,34 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
702683
gathered_prompts = [None for _ in range(self.tensor_parallel_size)]
703684
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
704685
all_prompts = [p for sublist in gathered_prompts for p in sublist]
686+
# Always gather images (even when None) to avoid deadlock: images may be None on some
687+
# ranks and non-None on others in mixed datasets, and all_gather_object is collective.
688+
local_images = images if images is not None else [None] * len(prompts)
689+
gathered_images = [None for _ in range(self.tensor_parallel_size)]
690+
torch.distributed.all_gather_object(gathered_images, local_images, group=self.tp_group)
691+
all_images = [img for sublist in gathered_images for img in sublist]
692+
if all(img is None for img in all_images):
693+
all_images = None
705694
else:
706695
all_prompts = prompts
696+
all_images = images
707697

708698
if self.enable_sleep_mode:
709699
self.llm.wake_up(tags=["kv_cache"])
710700

711-
with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
712-
if is_conversational({"prompt": prompts[0]}):
713-
all_outputs = self.llm.chat(
714-
all_prompts,
715-
sampling_params=sampling_params,
716-
use_tqdm=False,
717-
chat_template_kwargs=chat_template_kwargs,
718-
tools=tools,
719-
chat_template=chat_template,
720-
)
721-
else:
722-
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
701+
# Build vLLM-compatible prompt inputs with token IDs and optional multi-modal data
702+
vllm_prompts = []
703+
if all_images is not None:
704+
for ids, img_list in zip(all_prompts, all_images, strict=True):
705+
row = {"prompt_token_ids": ids}
706+
if img_list is not None:
707+
row["multi_modal_data"] = {"image": img_list if len(img_list) > 1 else img_list[0]}
708+
vllm_prompts.append(row)
709+
else:
710+
vllm_prompts = [{"prompt_token_ids": ids} for ids in all_prompts]
711+
712+
with profiler:
713+
all_outputs = self.llm.generate(vllm_prompts, sampling_params=sampling_params, use_tqdm=False)
723714

724715
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
725716
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

trl/trainer/grpo_trainer.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -741,10 +741,6 @@ def cast_outputs_to_original_dtype(module, args, output):
741741
max_completion_length=self.max_completion_length,
742742
logprobs=0, # we only need the generated token logprobs for the importance sampling correction
743743
generation_kwargs=args.generation_kwargs,
744-
# Chat/tool configuration
745-
chat_template=self.chat_template,
746-
chat_template_kwargs=self.chat_template_kwargs,
747-
tools=self.tools,
748744
)
749745
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
750746
else:
@@ -1226,10 +1222,43 @@ def _generate_single_turn(self, prompts: list):
12261222
self.vllm_generation.sync_weights()
12271223
self._last_loaded_step = self.state.global_step
12281224

1229-
# Generate using vLLM
1225+
# Tokenize prompts and extract images (for VLM) before calling vLLM
1226+
if is_conversational({"prompt": prompts[0]}):
1227+
# Extract images from messages for VLM support
1228+
images = []
1229+
has_images = False
1230+
for prompt in prompts:
1231+
prompt_images = []
1232+
for message in prompt:
1233+
if isinstance(message["content"], list):
1234+
for part in message["content"]:
1235+
if part["type"] == "image":
1236+
prompt_images.append(part["image"])
1237+
has_images = True
1238+
images.append(prompt_images if prompt_images else None)
1239+
images = images if has_images else None
1240+
1241+
tokenized = self.processing_class.apply_chat_template(
1242+
conversation=prompts,
1243+
tools=self.tools,
1244+
chat_template=self.chat_template,
1245+
add_generation_prompt=True,
1246+
tokenize=True,
1247+
return_dict=True,
1248+
**self.chat_template_kwargs,
1249+
)
1250+
prompt_token_ids = tokenized["input_ids"]
1251+
else:
1252+
prompt_token_ids = self.processing_class(text=prompts)["input_ids"]
1253+
images = None
1254+
1255+
# Generate using vLLM with raw token IDs
12301256
num_generations = self.num_generations if mode == "train" else self.num_generations_eval
12311257
prompt_ids, completion_ids, logprobs, _, extra_fields = self.vllm_generation.generate(
1232-
prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate")
1258+
prompts=prompt_token_ids,
1259+
images=images,
1260+
num_generations=num_generations,
1261+
profiler=profiling_context(self, "vLLM.generate"),
12331262
)
12341263
# vLLM returns per-token top-k logprobs; keep only the top-1 (sampled token) logprob
12351264
logprobs = [[lp[0] for lp in seq] for seq in logprobs]

trl/trainer/rloo_trainer.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,6 @@ def __init__(
525525
max_completion_length=self.max_completion_length,
526526
logprobs=None, # we don't need logprobs from vLLM in RLOO
527527
generation_kwargs=args.generation_kwargs,
528-
# Chat/tool configuration
529-
chat_template_kwargs=self.chat_template_kwargs,
530528
)
531529
self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
532530
else:
@@ -898,10 +896,42 @@ def _generate_single_turn(self, prompts: list):
898896
self.vllm_generation.sync_weights()
899897
self._last_loaded_step = self.state.global_step
900898

899+
# Tokenize prompts and extract images (for VLM) before calling vLLM
900+
if is_conversational({"prompt": prompts[0]}):
901+
# Extract images from messages for VLM support
902+
images = []
903+
has_images = False
904+
for prompt in prompts:
905+
prompt_images = []
906+
for message in prompt:
907+
if isinstance(message["content"], list):
908+
for part in message["content"]:
909+
if part["type"] == "image":
910+
prompt_images.append(part["image"])
911+
has_images = True
912+
images.append(prompt_images if prompt_images else None)
913+
images = images if has_images else None
914+
915+
# RLOO does not support tools; omit tools/chat_template args
916+
tokenized = self.processing_class.apply_chat_template(
917+
conversation=prompts,
918+
add_generation_prompt=True,
919+
tokenize=True,
920+
return_dict=True,
921+
**self.chat_template_kwargs,
922+
)
923+
prompt_token_ids = tokenized["input_ids"]
924+
else:
925+
prompt_token_ids = self.processing_class(text=prompts)["input_ids"]
926+
images = None
927+
901928
# Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them)
902929
num_generations = self.num_generations if mode == "train" else self.num_generations_eval
903930
prompt_ids, completion_ids, _, _, _ = self.vllm_generation.generate(
904-
prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate")
931+
prompts=prompt_token_ids,
932+
images=images,
933+
num_generations=num_generations,
934+
profiler=profiling_context(self, "vLLM.generate"),
905935
)
906936

907937
elif self.use_transformers_paged:

0 commit comments

Comments
 (0)