@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator
 from typing import Any

 import httpx
@@ -38,13 +38,6 @@
     LogProbConfig,
     Message,
     ModelStore,
-    OpenAIChatCompletion,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -71,11 +64,11 @@
     convert_message_to_openai_dict,
     convert_tool_call,
     get_sampling_options,
-    prepare_openai_completion_params,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     content_has_media,
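
The OpenAIMixin import added above is the crux of this change. A minimal sketch of the contract it appears to supply, assuming (not confirmed by this diff) that the mixin builds an AsyncOpenAI client on demand from the three hook methods the adapter overrides further down:

    # Sketch only -- the assumed shape of OpenAIMixin, inferred from this diff.
    from openai import AsyncOpenAI

    class OpenAIMixinSketch:
        def get_api_key(self) -> str:
            raise NotImplementedError

        def get_base_url(self) -> str:
            raise NotImplementedError

        def get_extra_client_params(self) -> dict:
            return {}  # e.g. a custom http_client

        @property
        def client(self) -> AsyncOpenAI:
            # Built per access, so subclasses need no lazy-init bookkeeping.
            return AsyncOpenAI(
                base_url=self.get_base_url(),
                api_key=self.get_api_key(),
                **self.get_extra_client_params(),
            )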
@@ -288,15 +281,14 @@ async def _process_vllm_chat_completion_stream_response(
         yield c


-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None

     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
-        self.client = None

     async def initialize(self) -> None:
         if not self.config.url:
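
With OpenAIMixin first among the base classes, the client is no longer state the adapter owns, so __init__ drops the self.client = None sentinel. A hypothetical construction, using only config fields visible in this diff (url, api_token, tls_verify):

    # Hypothetical usage sketch; the URL is just a typical local vLLM endpoint.
    import asyncio

    async def main() -> None:
        config = VLLMInferenceAdapterConfig(
            url="http://localhost:8000/v1",
            api_token="token",  # whatever your deployment expects, if anything
            tls_verify=True,
        )
        adapter = VLLMInferenceAdapter(config)
        await adapter.initialize()  # raises if config.url is unset

    asyncio.run(main())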
@@ -308,8 +300,6 @@ async def should_refresh_models(self) -> bool:
         return self.config.refresh_models

     async def list_models(self) -> list[Model] | None:
-        self._lazy_initialize_client()
-        assert self.client is not None  # mypy
         models = []
         async for m in self.client.models.list():
             model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ async def health(self) -> HealthResponse:
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            client = self._create_client() if self.client is None else self.client
-            _ = [m async for m in client.models.list()]  # Ensure the client is initialized
+            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
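
An assumed call pattern for the health check above, relying on the docstring's statement that HealthResponse behaves as a dictionary:

    # Sketch; dict-style access on HealthResponse is an assumption.
    async def check(adapter: VLLMInferenceAdapter) -> None:
        resp = await adapter.health()
        if resp["status"] == HealthStatus.ERROR:
            print(f"vLLM unhealthy: {resp['message']}")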
@@ -351,19 +340,14 @@ async def _get_model(self, model_id: str) -> Model:
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)

-    def _lazy_initialize_client(self):
-        if self.client is not None:
-            return
+    def get_api_key(self):
+        return self.config.api_token

-        log.info(f"Initializing vLLM client with base_url={self.config.url}")
-        self.client = self._create_client()
+    def get_base_url(self):
+        return self.config.url

-    def _create_client(self):
-        return AsyncOpenAI(
-            base_url=self.config.url,
-            api_key=self.config.api_token,
-            http_client=httpx.AsyncClient(verify=self.config.tls_verify),
-        )
+    def get_extra_client_params(self):
+        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

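The three hooks above replace _lazy_initialize_client/_create_client wholesale; the deleted code spells out exactly what the mixin must reassemble. Since httpx accepts either a bool or a CA-bundle path for verify, tls_verify covers both cases. For reference, wiring equivalent to the deleted _create_client():

    import httpx
    from openai import AsyncOpenAI

    def build_client(config) -> AsyncOpenAI:
        # Mirrors the deleted _create_client(); the mixin presumably does the
        # same internally via get_base_url(), get_api_key(), get_extra_client_params().
        return AsyncOpenAI(
            base_url=config.url,
            api_key=config.api_token,
            http_client=httpx.AsyncClient(verify=config.tls_verify),
        )
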
     async def completion(
         self,
@@ -374,7 +358,6 @@ async def completion(
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ async def chat_completion(
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
     ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
-        self._lazy_initialize_client()
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ async def _stream_completion(
             yield chunk

     async def register_model(self, model: Model) -> Model:
-        # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
-        # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
-        # Changing this may lead to unpredictable behavior.
-        client = self._create_client() if self.client is None else self.client
         try:
             model = await self.register_helper.register_model(model)
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing
         try:
-            res = await client.models.list()
+            res = await self.client.models.list()
         except APIConnectionError as e:
             raise ValueError(
                 f"Failed to connect to vLLM at {self.config.url}"
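
Assumed caller-side consequence: a vLLM server that is down now surfaces as a ValueError from register_model rather than a raw APIConnectionError:

    # Sketch of caller-side handling, inside an async context; names illustrative.
    try:
        model = await adapter.register_model(model)
    except ValueError as err:
        log.warning(f"vLLM unreachable or model rejected: {err}")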
@@ -543,8 +521,6 @@ async def embeddings(
         output_dimension: int | None = None,
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
         model = await self._get_model(model_id)

         kwargs = {}
@@ -560,154 +536,3 @@ async def embeddings(

         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        self._lazy_initialize_client()
-        assert self.client is not None
-        model_obj = await self._get_model(model)
-        assert model_obj.model_type == ModelType.embedding
-
-        # Convert input to list if it's a string
-        input_list = [input] if isinstance(input, str) else input
-
-        # Call vLLM embeddings endpoint with encoding_format
-        response = await self.client.embeddings.create(
-            model=model_obj.provider_resource_id,
-            input=input_list,
-            dimensions=dimensions,
-            encoding_format=encoding_format,
-        )
-
-        # Convert response to OpenAI format
-        data = [
-            OpenAIEmbeddingData(
-                embedding=embedding_data.embedding,
-                index=i,
-            )
-            for i, embedding_data in enumerate(response.data)
-        ]
-
-        # Not returning actual token usage since vLLM doesn't provide it
-        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.provider_resource_id,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-
-        extra_body: dict[str, Any] = {}
-        if prompt_logprobs is not None and prompt_logprobs >= 0:
-            extra_body["prompt_logprobs"] = prompt_logprobs
-        if guided_choice:
-            extra_body["guided_choice"] = guided_choice
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            extra_body=extra_body,
-        )
-        return await self.client.completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        self._lazy_initialize_client()
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self.client.chat.completions.create(**params)  # type: ignore
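
The roughly 150 lines deleted above are not lost behavior: openai_embeddings, openai_completion, and openai_chat_completion are now expected to be inherited from OpenAIMixin. A hypothetical call through the inherited surface:

    # Sketch; the model id and message shape are illustrative only. Runs in an
    # async context, with `adapter` an initialized VLLMInferenceAdapter.
    completion = await adapter.openai_chat_completion(
        model="my-vllm-model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )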