
Commit 8ef1189

chore: update the vLLM inference impl to use OpenAIMixin for openai-compat functions (#3404)
# What does this PR do?

Update the vLLM inference provider to use OpenAIMixin for the openai-compat functions.

Inference recordings were taken from Qwen3-0.6B on vLLM 0.8.3, started with:

```
docker run --gpus all -v ~/.cache/huggingface:/root/.cache/huggingface -p 8000:8000 --ipc=host \
    vllm/vllm-openai:latest \
    --model Qwen/Qwen3-0.6B --enable-auto-tool-choice --tool-call-parser hermes
```

## Test Plan

```
./scripts/integration-tests.sh --stack-config server:ci-tests --setup vllm --subdirs inference
```
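As a quick, informal sanity check against the container above (not part of the test plan; assumes the `openai` Python package is installed and that the server was started without `--api-key`, so any placeholder token is accepted), a chat completion can be sent straight to the OpenAI-compatible endpoint:

```python
# Hypothetical sanity check, not part of this commit or its test plan.
# Assumes the vLLM container from the description is listening on localhost:8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

resp = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(resp.choices[0].message.content)
```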
1 parent d15368a commit 8ef1189

File tree

3 files changed (+44, -202 lines)

llama_stack/providers/remote/inference/vllm/vllm.py

Lines changed: 11 additions & 186 deletions
```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any
 
 import httpx
@@ -38,13 +38,6 @@
     LogProbConfig,
     Message,
     ModelStore,
-    OpenAIChatCompletion,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -71,11 +64,11 @@
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     content_has_media,
@@ -288,15 +281,14 @@ async def _process_vllm_chat_completion_stream_response(
         yield c
 
 
-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
-        self.client = None
 
     async def initialize(self) -> None:
         if not self.config.url:
@@ -308,8 +300,6 @@ async def should_refresh_models(self) -> bool:
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
-        self._lazy_initialize_client()
-        assert self.client is not None  # mypy
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ async def health(self) -> HealthResponse:
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            client = self._create_client() if self.client is None else self.client
-            _ = [m async for m in client.models.list()]  # Ensure the client is initialized
+            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -351,19 +340,14 @@ async def _get_model(self, model_id: str) -> Model:
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
-    def _lazy_initialize_client(self):
-        if self.client is not None:
-            return
+    def get_api_key(self):
+        return self.config.api_token
 
-        log.info(f"Initializing vLLM client with base_url={self.config.url}")
-        self.client = self._create_client()
+    def get_base_url(self):
+        return self.config.url
 
-    def _create_client(self):
-        return AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=httpx.AsyncClient(verify=self.config.tls_verify),
-        )
+    def get_extra_client_params(self):
+        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
     async def completion(
         self,
@@ -374,7 +358,6 @@ async def completion(
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ async def chat_completion(
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ async def _stream_completion(
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
-        # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
-        # Changing this may lead to unpredictable behavior.
-        client = self._create_client() if self.client is None else self.client
         try:
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
        try:
-            res = await client.models.list()
+            res = await self.client.models.list()
         except APIConnectionError as e:
             raise ValueError(
                 f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
@@ -543,8 +521,6 @@ async def embeddings(
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
         model = await self._get_model(model_id)
 
         kwargs = {}
@@ -560,154 +536,3 @@ async def embeddings(
 
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
-        model_obj = await self._get_model(model)
-        assert model_obj.model_type == ModelType.embedding
-
-        # Convert input to list if it's a string
-        input_list = [input] if isinstance(input, str) else input
-
-        # Call vLLM embeddings endpoint with encoding_format
-        response = await self.client.embeddings.create(
-            model=model_obj.provider_resource_id,
-            input=input_list,
-            dimensions=dimensions,
-            encoding_format=encoding_format,
-        )
-
-        # Convert response to OpenAI format
-        data = [
-            OpenAIEmbeddingData(
-                embedding=embedding_data.embedding,
-                index=i,
-            )
-            for i, embedding_data in enumerate(response.data)
-        ]
-
-        # Not returning actual token usage since vLLM doesn't provide it
-        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.provider_resource_id,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            extra_body=extra_body,
-        )
-        return await self.client.completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self.client.chat.completions.create(**params)  # type: ignore
```
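With this change the adapter's OpenAI-client wiring reduces to the three small hooks shown above. The sketch below is illustrative only (`SketchConfig` and `SketchAdapter` are made-up stand-ins, not classes from this repo); it shows how hooks like these combine with a `client` property in the style of OpenAIMixin to yield a ready-to-use `AsyncOpenAI` client without any lazy-initialization bookkeeping:

```python
# Minimal, self-contained sketch of the hook pattern the adapter now relies on.
# Hook names mirror get_api_key / get_base_url / get_extra_client_params;
# everything else here is hypothetical.
import asyncio

import httpx
from openai import AsyncOpenAI


class SketchConfig:
    url = "http://localhost:8000/v1"
    api_token = "not-needed"
    tls_verify = True


class SketchAdapter:
    def __init__(self, config: SketchConfig) -> None:
        self.config = config

    # the three overrides the vLLM adapter provides
    def get_api_key(self) -> str:
        return self.config.api_token

    def get_base_url(self) -> str:
        return self.config.url

    def get_extra_client_params(self) -> dict:
        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

    # roughly what the mixin's `client` property does: build a fresh AsyncOpenAI
    # from the hooks on each access, so no self.client bookkeeping is needed
    @property
    def client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )


async def main() -> None:
    adapter = SketchAdapter(SketchConfig())
    # same call path that health() and register_model() now use via self.client
    async for m in adapter.client.models.list():
        print(m.id)


if __name__ == "__main__":
    asyncio.run(main())
```

Because the client is rebuilt from config on access, methods such as `health()` and `register_model()` can simply use `self.client` regardless of when they run during stack initialization.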

llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -67,6 +67,17 @@ def get_base_url(self) -> str:
         """
         pass
 
+    def get_extra_client_params(self) -> dict[str, Any]:
+        """
+        Get any extra parameters to pass to the AsyncOpenAI client.
+
+        Child classes can override this method to provide additional parameters
+        such as timeout settings, proxies, etc.
+
+        :return: A dictionary of extra parameters
+        """
+        return {}
+
     @property
     def client(self) -> AsyncOpenAI:
         """
@@ -78,6 +89,7 @@ def client(self) -> AsyncOpenAI:
         return AsyncOpenAI(
             api_key=self.get_api_key(),
             base_url=self.get_base_url(),
+            **self.get_extra_client_params(),
         )
 
     async def _get_provider_model_id(self, model: str) -> str:
@@ -124,10 +136,15 @@ async def openai_completion(
         """
         Direct OpenAI completion API call.
         """
-        if guided_choice is not None:
-            logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
-        if prompt_logprobs is not None:
-            logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
+        # Handle parameters that are not supported by OpenAI API, but may be by the provider
+        # prompt_logprobs is supported by vLLM
+        # guided_choice is supported by vLLM
+        # TODO: test coverage
+        extra_body: dict[str, Any] = {}
+        if prompt_logprobs is not None and prompt_logprobs >= 0:
+            extra_body["prompt_logprobs"] = prompt_logprobs
+        if guided_choice:
+            extra_body["guided_choice"] = guided_choice
 
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         return await self.client.completions.create(  # type: ignore[no-any-return]
@@ -150,7 +167,8 @@ async def openai_completion(
                 top_p=top_p,
                 user=user,
                 suffix=suffix,
-            )
+            ),
+            extra_body=extra_body,
         )
 
     async def openai_chat_completion(
```
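The net effect of this hunk is that vLLM-specific parameters now ride along in `extra_body` instead of being dropped with a warning. A hedged sketch of the resulting request shape, using the `openai` client directly against the server from the description (the endpoint, model name, and placeholder API key are assumptions, not values taken from this commit):

```python
# Illustrative only: shows how guided_choice / prompt_logprobs are forwarded
# verbatim in the JSON body via extra_body, which is what the mixin now does.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
    completion = await client.completions.create(
        model="Qwen/Qwen3-0.6B",
        prompt="The capital of France is",
        max_tokens=5,
        # vLLM extensions that the OpenAI API itself does not define
        extra_body={"guided_choice": [" Paris", " London"], "prompt_logprobs": 0},
    )
    print(completion.choices[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```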
