@@ -207,49 +207,28 @@ async def create_chat_completion(
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
-            for i, engine_prompt in enumerate(engine_prompts):
-                sampling_params: Union[SamplingParams, BeamSearchParams]
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
-                if request.use_beam_search:
-                    sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
-                else:
-                    sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
-                        default_sampling_params)
-
-                self._log_inputs(request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
-
-                trace_headers = (None if raw_request is None else await
-                                 self._get_trace_headers(raw_request.headers))
-
-                if isinstance(sampling_params, BeamSearchParams):
-                    generator = self.engine_client.beam_search(
-                        prompt=engine_prompt,
-                        request_id=request_id,
-                        params=sampling_params,
-                    )
-                else:
-                    generator = self.engine_client.generate(
-                        engine_prompt,
-                        sampling_params,
-                        request_id,
-                        lora_request=lora_request,
-                        trace_headers=trace_headers,
-                        prompt_adapter_request=prompt_adapter_request,
-                        priority=request.priority,
-                    )
-
-                generators.append(generator)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request,
+                prompt=prompt,
+                add_special_tokens=request.add_special_tokens)
+            sampling_params = request.to_sampling_params()
+            if request.enforced_str:
+                toks = self.tokenizer(request.enforced_str, add_special_tokens=False)
+                sampling_params.enforce_token_ids = toks.input_ids + [self.tokenizer.eos_token_id]
+            lora_request = self._maybe_get_lora(request)
+            decoding_config = await self.engine.get_decoding_config()
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
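The hunk above tokenizes `request.enforced_str` and attaches the resulting ids to `sampling_params.enforce_token_ids`, but the consuming side is not part of this diff. As a rough illustration only, the sketch below shows one plausible consumer written in the same shape as the guided-decoding hook already used here, i.e. a logits processor taking `(generated_token_ids, logits)`; the name `make_enforcing_logits_processor` and the exact masking strategy are assumptions, not code from this commit.

import torch
from typing import List


def make_enforcing_logits_processor(enforce_token_ids: List[int]):
    """Force decoding to emit enforce_token_ids verbatim, one id per step.

    Hypothetical sketch: mirrors the (token_ids, logits) -> logits
    processor shape used for guided decoding in the hunk above.
    """

    def processor(generated_token_ids: List[int],
                  logits: torch.Tensor) -> torch.Tensor:
        step = len(generated_token_ids)
        if step < len(enforce_token_ids):
            # Mask everything except the id scheduled for this step.
            forced = enforce_token_ids[step]
            masked = torch.full_like(logits, float("-inf"))
            masked[forced] = logits[forced]
            return masked
        # Past the enforced prefix: leave the distribution untouched.
        return logits

    return processor

If wired up the same way as the guided-decoding processor, it would be appended to `sampling_params.logits_processors`; once every enforced id has been emitted, decoding continues unconstrained (in practice the trailing `eos_token_id` ends the request).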