 
 import asyncio
 import base64
-import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
-from ollama import AsyncClient  # type: ignore[attr-defined]
-from openai import AsyncOpenAI
+from ollama import AsyncClient as AsyncOllamaClient
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
...
     Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ResponseFormat,
...
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
-    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
...
 
 
 class OllamaInferenceAdapter(
+    OpenAIMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
@@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.config = config
-        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
-        self._openai_client = None
+        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
 
     @property
-    def client(self) -> AsyncClient:
+    def ollama_client(self) -> AsyncOllamaClient:
         # ollama client attaches itself to the current event loop (sadly?)
         loop = asyncio.get_running_loop()
         if loop not in self._clients:
-            self._clients[loop] = AsyncClient(host=self.config.url)
+            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
         return self._clients[loop]
 
-    @property
-    def openai_client(self) -> AsyncOpenAI:
-        if self._openai_client is None:
-            url = self.config.url.rstrip("/")
-            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
-        return self._openai_client
+    def get_api_key(self):
+        return "NO_KEY"
+
+    def get_base_url(self):
+        return self.config.url.rstrip("/") + "/v1"
 
     async def initialize(self) -> None:
         logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@@ -129,7 +122,7 @@ async def should_refresh_models(self) -> bool:
 
     async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        response = await self.client.list()
+        response = await self.ollama_client.list()
 
         # always add the two embedding models which can be pulled on demand
         models = [
@@ -189,7 +182,7 @@ async def health(self) -> HealthResponse:
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            await self.client.ps()
+            await self.ollama_client.ps()
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -238,7 +231,7 @@ async def _stream_completion(
         params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
@@ -254,7 +247,7 @@ async def _generate_and_convert_to_openai_compat():
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.generate(**params)
+        r = await self.ollama_client.generate(**params)
 
         choice = OpenAICompatCompletionChoice(
             finish_reason=r["done_reason"] if r["done"] else None,
@@ -346,9 +339,9 @@ async def _get_params(self, request: ChatCompletionRequest | CompletionRequest)
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self.client.chat(**params)
+            r = await self.ollama_client.chat(**params)
         else:
-            r = await self.client.generate(**params)
+            r = await self.ollama_client.generate(**params)
 
         if "message" in r:
             choice = OpenAICompatCompletionChoice(
@@ -372,9 +365,9 @@ async def _stream_chat_completion(
 
         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:
-                s = await self.client.chat(**params)
+                s = await self.ollama_client.chat(**params)
             else:
-                s = await self.client.generate(**params)
+                s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 if "message" in chunk:
                     choice = OpenAICompatCompletionChoice(
@@ -407,7 +400,7 @@ async def embeddings(
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
         )
-        response = await self.client.embed(
+        response = await self.ollama_client.embed(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
@@ -422,14 +415,14 @@ async def register_model(self, model: Model) -> Model:
             pass  # Ignore statically unknown model, will check live listing
 
         if model.model_type == ModelType.embedding:
-            response = await self.client.list()
+            response = await self.ollama_client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
-                await self.client.pull(model.provider_resource_id)
+                await self.ollama_client.pull(model.provider_resource_id)
 
         # we use list() here instead of ps() -
         # - ps() only lists running models, not available models
         # - models not currently running are run by the ollama server as needed
-        response = await self.client.list()
+        response = await self.ollama_client.list()
         available_models = [m.model for m in response.models]
 
         provider_resource_id = model.provider_resource_id
@@ -448,90 +441,6 @@ async def register_model(self, model: Model) -> Model:
 
         return model
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self._get_model(model)
-        if model_obj.provider_resource_id is None:
-            raise ValueError(f"Model {model} has no provider_resource_id set")
-
-        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        params = prepare_openai_embeddings_params(
-            model=model_obj.provider_resource_id,
-            input=input,
-            encoding_format=encoding_format,
-            dimensions=dimensions,
-            user=user,
-        )
-
-        response = await self.openai_client.embeddings.create(**params)
-        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
-
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-        # TODO: Investigate why model_obj.identifier is used instead of response.model
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.identifier,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        if not isinstance(prompt, str):
-            raise ValueError("Ollama does not support non-string prompts for completion")
-
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            suffix=suffix,
-        )
-        return await self.openai_client.completions.create(**params)  # type: ignore
-
     async def openai_chat_completion(
         self,
         model: str,
@@ -599,25 +508,7 @@ async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
             top_p=top_p,
             user=user,
         )
-        response = await self.openai_client.chat.completions.create(**params)
-        return await self._adjust_ollama_chat_completion_response_ids(response)
-
-    async def _adjust_ollama_chat_completion_response_ids(
-        self,
-        response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        id = f"chatcmpl-{uuid.uuid4()}"
-        if isinstance(response, AsyncIterator):
-
-            async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
-                async for chunk in response:
-                    chunk.id = id
-                    yield chunk
-
-            return stream_with_chunk_ids()
-        else:
-            response.id = id
-            return response
+        return await OpenAIMixin.openai_chat_completion(self, **params)
 
 
 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
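
For orientation only, not part of the commit: the diff swaps the adapter's hand-rolled AsyncOpenAI property for the shared OpenAIMixin, which derives its OpenAI-compatible client from two small hooks, get_api_key() and get_base_url(). The following is a minimal sketch of that hook pattern under stated assumptions; the class names are hypothetical and this is not the actual llama_stack OpenAIMixin implementation.

# Sketch of the hook pattern the commit adopts; names are illustrative.
from openai import AsyncOpenAI


class SketchOpenAIMixin:
    """Builds an OpenAI-compatible client from two adapter-provided hooks."""

    def get_api_key(self) -> str:
        raise NotImplementedError

    def get_base_url(self) -> str:
        raise NotImplementedError

    @property
    def client(self) -> AsyncOpenAI:
        # New client per access for simplicity; a real mixin would likely cache it.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())


class SketchOllamaAdapter(SketchOpenAIMixin):
    def __init__(self, url: str) -> None:
        self.url = url

    def get_api_key(self) -> str:
        # Ollama ignores the key, but the OpenAI client requires a non-empty string.
        return "NO_KEY"

    def get_base_url(self) -> str:
        # Ollama exposes an OpenAI-compatible API under /v1.
        return self.url.rstrip("/") + "/v1"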