@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any

 import httpx
@@ -38,13 +38,6 @@
     LogProbConfig,
     Message,
     ModelStore,
-    OpenAIChatCompletion,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -71,11 +64,11 @@
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     content_has_media,
@@ -288,15 +281,14 @@ async def _process_vllm_chat_completion_stream_response(
         yield c


-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None

     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
-        self.client = None

     async def initialize(self) -> None:
         if not self.config.url:
@@ -308,8 +300,6 @@ async def should_refresh_models(self) -> bool:
         return self.config.refresh_models

     async def list_models(self) -> list[Model] | None:
-        self._lazy_initialize_client()
-        assert self.client is not None  # mypy
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ async def health(self) -> HealthResponse:
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            client = self._create_client() if self.client is None else self.client
-            _ = [m async for m in client.models.list()]  # Ensure the client is initialized
+            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
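
The health check is effectively "can we list models over the OpenAI-compatible API". For illustration, here is a standalone sketch of the same probe using plain httpx; it assumes the configured URL already ends in /v1, as vLLM's OpenAI-compatible server expects, and the token value is a placeholder:

# Standalone sketch of what health() amounts to: listing models against
# vLLM's OpenAI-compatible endpoint. base_url ending in /v1 is assumed.
import asyncio

import httpx


async def probe(base_url: str, api_token: str, tls_verify: bool = True) -> bool:
    async with httpx.AsyncClient(verify=tls_verify) as http:
        resp = await http.get(
            f"{base_url}/models",
            headers={"Authorization": f"Bearer {api_token}"},
        )
        return resp.status_code == 200


if __name__ == "__main__":
    print(asyncio.run(probe("http://localhost:8000/v1", "placeholder-token")))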
@@ -351,19 +340,14 @@ async def _get_model(self, model_id: str) -> Model:
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)

-    def _lazy_initialize_client(self):
-        if self.client is not None:
-            return
+    def get_api_key(self):
+        return self.config.api_token

-        log.info(f"Initializing vLLM client with base_url={self.config.url}")
-        self.client = self._create_client()
+    def get_base_url(self):
+        return self.config.url

-    def _create_client(self):
-        return AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=httpx.AsyncClient(verify=self.config.tls_verify),
-        )
+    def get_extra_client_params(self):
+        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

     async def completion(
         self,
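
These three small hooks are all the adapter now has to supply; the client construction and caching move into the mixin. The real implementation lives in llama_stack.providers.utils.inference.openai_mixin; the sketch below is only the minimal contract inferred from this diff (a client property assembled from the hooks), not the actual code:

# Hypothetical sketch of the OpenAIMixin contract this diff relies on.
# Names and behavior are assumptions drawn from how the adapter overrides them.
from typing import Any

from openai import AsyncOpenAI


class OpenAIMixinSketch:
    def get_api_key(self) -> str:
        raise NotImplementedError  # supplied by the adapter

    def get_base_url(self) -> str:
        raise NotImplementedError  # supplied by the adapter

    def get_extra_client_params(self) -> dict[str, Any]:
        return {}  # optional extras, e.g. a custom http_client

    @property
    def client(self) -> AsyncOpenAI:
        # Build the client on demand from the hooks; the real mixin may
        # cache it or add retry/timeout policy.
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )

Constructing the client on first use rather than in __init__ is what lets the diff delete _lazy_initialize_client() at every call site below.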
@@ -374,7 +358,6 @@ async def completion(
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ async def chat_completion(
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ async def _stream_completion(
             yield chunk

     async def register_model(self, model: Model) -> Model:
-        # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
-        # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
-        # Changing this may lead to unpredictable behavior.
-        client = self._create_client() if self.client is None else self.client
         try:
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
         try:
-            res = await self.client.models.list()
         except APIConnectionError as e:
             raise ValueError(
                 f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
@@ -543,8 +521,6 @@ async def embeddings(
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
         model = await self._get_model(model_id)

         kwargs = {}
@@ -560,154 +536,3 @@

         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
-        model_obj = await self._get_model(model)
-        assert model_obj.model_type == ModelType.embedding
-
-        # Convert input to list if it's a string
-        input_list = [input] if isinstance(input, str) else input
-
-        # Call vLLM embeddings endpoint with encoding_format
-        response = await self.client.embeddings.create(
-            model=model_obj.provider_resource_id,
-            input=input_list,
-            dimensions=dimensions,
-            encoding_format=encoding_format,
-        )
-
-        # Convert response to OpenAI format
-        data = [
-            OpenAIEmbeddingData(
-                embedding=embedding_data.embedding,
-                index=i,
-            )
-            for i, embedding_data in enumerate(response.data)
-        ]
-
-        # Not returning actual token usage since vLLM doesn't provide it
-        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.provider_resource_id,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            extra_body=extra_body,
-        )
-        return await self.client.completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self.client.chat.completions.create(**params)  # type: ignore
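
The three deleted methods are not lost functionality: after this change they are inherited from OpenAIMixin, so callers see the same OpenAI-compatible surface. A hedged usage sketch, where adapter construction, initialization, and model registration are elided and the model id is an assumption:

# Usage sketch: the OpenAI-compatible entry points survive the refactor,
# now inherited from OpenAIMixin rather than overridden per adapter.
# `adapter` is a constructed-and-initialized VLLMInferenceAdapter with the
# model already registered; the model id below is an assumption.
import asyncio


async def demo(adapter) -> None:
    completion = await adapter.openai_chat_completion(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Say hi in one word."}],
        temperature=0.0,
    )
    print(completion.choices[0].message.content)

# asyncio.run(demo(adapter))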