Commit 64271a3

Merge pull request #1 from product-science/gm/validation
Gm/validation
2 parents ed6e907 + 9b6c2be commit 64271a3

File tree

5 files changed: +282 −221 lines


vllm/entrypoints/openai/protocol.py

Lines changed: 1 addition & 23 deletions
@@ -346,29 +346,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default whitespace pattern "
             "for guided json decoding."))
-    priority: int = Field(
-        default=0,
-        description=(
-            "The priority of the request (lower means earlier handling; "
-            "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
-    request_id: str = Field(
-        default_factory=lambda: f"{random_uuid()}",
-        description=(
-            "The request_id related to this request. If the caller does "
-            "not set it, a random_uuid will be generated. This id is used "
-            "through out the inference process and return in response."))
-    logits_processors: Optional[LogitsProcessors] = Field(
-        default=None,
-        description=(
-            "A list of either qualified names of logits processors, or "
-            "constructor objects, to apply when sampling. A constructor is "
-            "a JSON object with a required 'qualname' field specifying the "
-            "qualified name of the processor class/factory, and optional "
-            "'args' and 'kwargs' fields containing positional and keyword "
-            "arguments. For example: {'qualname': "
-            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
-            "{'param': 'value'}}."))
+    enforced_str: Optional[str] = Field(default=None)
 
     # doc: end-chat-completion-extra-params

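Taken together, this hunk drops the upstream priority, request_id, and logits_processors extra params in favor of a single enforced_str field. Since enforced_str is not part of the standard OpenAI schema, a client would send it as an extra body field. The snippet below is a minimal sketch, assuming a vLLM server on localhost:8000 and a placeholder model name; extra_body is openai-python's passthrough for non-standard fields.

# Hypothetical client call exercising the new `enforced_str` extra param.
# Server URL and model name are placeholders, not taken from this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="my-served-model",  # placeholder
    messages=[{"role": "user", "content": "Repeat the enforced text."}],
    extra_body={"enforced_str": "The answer is 42."},  # new field from this commit
)
print(response.choices[0].message.content)
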
vllm/entrypoints/openai/serving_chat.py

Lines changed: 22 additions & 43 deletions
@@ -207,49 +207,28 @@ async def create_chat_completion(
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
-            for i, engine_prompt in enumerate(engine_prompts):
-                sampling_params: Union[SamplingParams, BeamSearchParams]
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
-                # Build default sampling params
-                default_sampling_params = (
-                    self.model_config.get_diff_sampling_param())
-                if request.use_beam_search:
-                    sampling_params = request.to_beam_search_params(
-                        default_max_tokens, default_sampling_params)
-                else:
-                    sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
-                        default_sampling_params)
-
-                self._log_inputs(request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
-
-                trace_headers = (None if raw_request is None else await
-                                 self._get_trace_headers(raw_request.headers))
-
-                if isinstance(sampling_params, BeamSearchParams):
-                    generator = self.engine_client.beam_search(
-                        prompt=engine_prompt,
-                        request_id=request_id,
-                        params=sampling_params,
-                    )
-                else:
-                    generator = self.engine_client.generate(
-                        engine_prompt,
-                        sampling_params,
-                        request_id,
-                        lora_request=lora_request,
-                        trace_headers=trace_headers,
-                        prompt_adapter_request=prompt_adapter_request,
-                        priority=request.priority,
-                    )
-
-                generators.append(generator)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request,
+                prompt=prompt,
+                add_special_tokens=request.add_special_tokens)
+            sampling_params = request.to_sampling_params()
+            if request.enforced_str:
+                toks = self.tokenizer(request.enforced_str, add_special_tokens=False)
+                sampling_params.enforce_token_ids = toks.input_ids + [self.tokenizer.eos_token_id]
+            lora_request = self._maybe_get_lora(request)
+            decoding_config = await self.engine.get_decoding_config()
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
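The handler tokenizes enforced_str without special tokens, appends eos_token_id, and stores the result on sampling_params.enforce_token_ids; how the engine consumes that field is outside this diff. A common way to realize such enforcement is a logits processor that masks every token except the next one in the scheduled sequence. The following is a sketch under that assumption, using vLLM's two-argument logits-processor signature, not the fork's actual engine code.

# Sketch: force decoding to reproduce `enforce_token_ids` one token per step
# by masking all other logits. Assumes 1-D per-sequence logits, as vLLM
# passes to its logits processors.
from typing import List

import torch


def make_enforce_processor(enforce_token_ids: List[int]):
    def processor(generated_token_ids: List[int],
                  logits: torch.Tensor) -> torch.Tensor:
        step = len(generated_token_ids)
        if step < len(enforce_token_ids):
            mask = torch.full_like(logits, float("-inf"))
            mask[enforce_token_ids[step]] = 0.0  # only the scheduled token survives
            return logits + mask
        return logits  # past the enforced span: sample normally

    return processor

Because the diff appends eos_token_id to the enforced sequence, such a processor would terminate generation immediately after the enforced text is emitted.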
