1414
1515"""vLLM-based generation backend for TRL trainers."""
1616
17- import json
1817import logging
1918import math
2019import os
2928from transformers import PreTrainedModel , PreTrainedTokenizerBase , ProcessorMixin , is_bitsandbytes_available
3029from transformers .utils import is_torch_mlu_available , is_torch_npu_available , is_torch_xpu_available
3130
32- from ..data_utils import is_conversational , prepare_multimodal_messages_vllm
3331from ..extras .profiling import ProfilingContext
3432from ..import_utils import is_vllm_available
3533from ..trainer .utils import ensure_master_addr_port
@@ -245,10 +243,6 @@ def __init__(
245243 max_completion_length : int = 16 ,
246244 logprobs : int | None = 0 ,
247245 generation_kwargs : dict | None = None ,
248- # Chat/tool configuration
249- chat_template : str | None = None ,
250- chat_template_kwargs : dict | None = None ,
251- tools : list | None = None ,
252246 ):
253247 self .model = model
254248 self .accelerator = accelerator
@@ -284,11 +278,6 @@ def __init__(
284278 self .logprobs = logprobs
285279 self .generation_kwargs = generation_kwargs or {}
286280
287- # Chat/tool configuration
288- self .chat_template = chat_template
289- self .chat_template_kwargs = chat_template_kwargs or {}
290- self .tools = tools
291-
292281 self ._init_vllm ()
293282
294283 def _init_vllm (self ):
@@ -528,13 +517,21 @@ def sync_weights(self):
528517 elif self .mode == "colocate" :
529518 self .llm .reset_prefix_cache ()
530519
531- def generate (self , prompts : list , num_generations : int , profiler : ProfilingContext | None = None ) -> tuple :
520+ def generate (
521+ self ,
522+ prompts : list [list [int ]],
523+ images : list [list | None ] | None ,
524+ num_generations : int ,
525+ profiler : ProfilingContext | None = None ,
526+ ) -> tuple :
532527 """Generate completions using vLLM.
533528
534529 Args:
535- prompts: List of prompts (strings or chat conversations)
536- num_generations: Number of generations per prompt
537- profiler: Optional profiler for performance tracking
530+ prompts: List of token ID lists, one per prompt (already tokenized).
531+ images: Optional list of image lists for VLM support. Each element is a list of PIL images for the
532+ corresponding prompt, or `None` if that prompt has no images. Pass `None` if no prompt has images.
533+ num_generations: Number of generations per prompt.
534+ profiler: Optional profiler for performance tracking.
538535
539536 Returns:
540537 Tuple of (prompt_ids, completion_ids, logprobs, logprob_token_ids, extra_fields).
@@ -567,9 +564,6 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
567564 min_p = self .min_p
568565 repetition_penalty = self .repetition_penalty
569566 max_completion_length = self .max_completion_length
570- chat_template_kwargs = self .chat_template_kwargs
571- tools = self .tools
572- chat_template = self .chat_template
573567
574568 # Wake up colocated vLLM weights if needed (idempotent if already awake from sync_weights)
575569 if self .mode == "colocate" and self .enable_sleep_mode :
@@ -582,28 +576,21 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
582576 # Non-CUDA vLLM backends (e.g., vllm-ascend's NPUWorkerV1) don't implement reload_weights
583577 pass
584578
585- if is_conversational ({"prompt" : prompts [0 ]}):
586- prompts = [prepare_multimodal_messages_vllm (prompt ) for prompt in prompts ]
587-
588- # In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820
589- for prompt in prompts : # iterate over each conversation
590- if is_conversational ({"prompt" : prompt }):
591- for message in prompt : # iterate over each message
592- if "tool_calls" in message : # check if message has tool calls
593- for call in message ["tool_calls" ]:
594- args_value = call ["function" ]["arguments" ]
595- if isinstance (args_value , dict ): # only convert dict → JSON string
596- call ["function" ]["arguments" ] = json .dumps (args_value )
597-
598579 # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
599580 if self .mode == "server" :
600581 all_prompts = gather_object (prompts )
582+ # Always gather images (even when None) to avoid deadlock: images may be None on some ranks
583+ # and non-None on others in mixed datasets, and gather_object is a collective operation.
584+ all_images = gather_object (images if images is not None else [None ] * len (prompts ))
585+ if all (img is None for img in all_images ):
586+ all_images = None
601587
602588 if accelerator .is_main_process :
603- # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
604- # num_generations outputs for each one. This is faster than generating outputs for each duplicate
605- # prompt individually.
606- ordered_set_of_prompts = all_prompts [::num_generations ]
589+ # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and
590+ # generate num_generations outputs for each one. This is faster than generating outputs for each
591+ # duplicate prompt individually.
592+ ordered_set_of_prompt_ids = all_prompts [::num_generations ]
593+ ordered_set_of_images = all_images [::num_generations ] if all_images is not None else None
607594
608595 sampling_params = {
609596 "n" : num_generations ,
@@ -617,18 +604,12 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
617604 "structured_outputs_regex" : self .structured_outputs_regex ,
618605 "generation_kwargs" : self .generation_kwargs ,
619606 }
620- with profiler : # TODO: profiling_context(trainer, "vLLM.generate"):
621- if is_conversational ({"prompt" : ordered_set_of_prompts [0 ]}):
622- output = self .vllm_client .chat (
623- messages = ordered_set_of_prompts ,
624- ** sampling_params ,
625- chat_template_kwargs = chat_template_kwargs ,
626- tools = tools ,
627- chat_template = chat_template ,
628- )
629- else :
630- ordered_set_of_prompt_ids = self .processing_class (text = ordered_set_of_prompts )["input_ids" ]
631- output = self .vllm_client .generate (prompts = ordered_set_of_prompt_ids , ** sampling_params )
607+ with profiler :
608+ output = self .vllm_client .generate (
609+ prompts = ordered_set_of_prompt_ids ,
610+ images = ordered_set_of_images ,
611+ ** sampling_params ,
612+ )
632613 # Extract required fields and collect any extra fields for reward functions
633614 required_keys = {"prompt_ids" , "completion_ids" , "logprobs" , "logprob_token_ids" }
634615 extra_fields = {k : v for k , v in output .items () if k not in required_keys }
@@ -647,7 +628,7 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
647628 broadcast_object_list (obj_list , from_process = 0 )
648629 all_prompt_ids , all_completion_ids , all_logprobs , all_logprob_token_ids , all_extra_fields = obj_list [0 ]
649630
650- # vllm_client.generate/chat (n=num_generations) returns num_generations completions per prompt.
631+ # vllm_client.generate(n=num_generations) returns num_generations completions per prompt.
651632 # Duplicate prompt_ids to align with per-completion entries.
652633 all_prompt_ids = [ids for ids in all_prompt_ids for _ in range (num_generations )]
653634
@@ -702,24 +683,34 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
702683 gathered_prompts = [None for _ in range (self .tensor_parallel_size )]
703684 torch .distributed .all_gather_object (gathered_prompts , prompts , group = self .tp_group )
704685 all_prompts = [p for sublist in gathered_prompts for p in sublist ]
686+ # Always gather images (even when None) to avoid deadlock: images may be None on some
687+ # ranks and non-None on others in mixed datasets, and all_gather_object is collective.
688+ local_images = images if images is not None else [None ] * len (prompts )
689+ gathered_images = [None for _ in range (self .tensor_parallel_size )]
690+ torch .distributed .all_gather_object (gathered_images , local_images , group = self .tp_group )
691+ all_images = [img for sublist in gathered_images for img in sublist ]
692+ if all (img is None for img in all_images ):
693+ all_images = None
705694 else :
706695 all_prompts = prompts
696+ all_images = images
707697
708698 if self .enable_sleep_mode :
709699 self .llm .wake_up (tags = ["kv_cache" ])
710700
711- with profiler : # TODO: profiling_context(trainer, "vLLM.generate"):
712- if is_conversational ({"prompt" : prompts [0 ]}):
713- all_outputs = self .llm .chat (
714- all_prompts ,
715- sampling_params = sampling_params ,
716- use_tqdm = False ,
717- chat_template_kwargs = chat_template_kwargs ,
718- tools = tools ,
719- chat_template = chat_template ,
720- )
721- else :
722- all_outputs = self .llm .generate (all_prompts , sampling_params = sampling_params , use_tqdm = False )
701+ # Build vLLM-compatible prompt inputs with token IDs and optional multi-modal data
702+ vllm_prompts = []
703+ if all_images is not None :
704+ for ids , img_list in zip (all_prompts , all_images , strict = True ):
705+ row = {"prompt_token_ids" : ids }
706+ if img_list is not None :
707+ row ["multi_modal_data" ] = {"image" : img_list if len (img_list ) > 1 else img_list [0 ]}
708+ vllm_prompts .append (row )
709+ else :
710+ vllm_prompts = [{"prompt_token_ids" : ids } for ids in all_prompts ]
711+
712+ with profiler :
713+ all_outputs = self .llm .generate (vllm_prompts , sampling_params = sampling_params , use_tqdm = False )
723714
724715 all_prompt_ids = [output .prompt_token_ids for output in all_outputs ]
725716 all_completion_ids = [output .token_ids for outputs in all_outputs for output in outputs .outputs ]
0 commit comments