
Commit 5a6f28c

mattfiamemilio authored and committed
chore: update the ollama inference impl to use OpenAIMixin for openai-compat functions (llamastack#3395)
# What does this PR do?

Update the Ollama inference provider to use OpenAIMixin for the openai-compat endpoints.

## Test Plan

ci
1 parent c2f2ca0 commit 5a6f28c
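
To give a sense of the pattern this commit moves to: the adapter no longer builds and caches its own AsyncOpenAI client; instead it implements two small hooks (get_api_key and get_base_url) that the shared OpenAIMixin uses to construct the client and serve the openai-compat endpoints. The sketch below is a simplified illustration of that split, not the real OpenAIMixin: the OpenAICompatMixinSketch/OllamaAdapterSketch names and the AsyncOpenAI wiring inside the mixin are assumptions for illustration; only the get_api_key/get_base_url hooks, the "NO_KEY" placeholder, and the "/v1" base URL come from the diff itself.

from openai import AsyncOpenAI


class OpenAICompatMixinSketch:
    """Hypothetical stand-in for OpenAIMixin: subclasses only supply connection details."""

    def get_api_key(self) -> str:
        raise NotImplementedError

    def get_base_url(self) -> str:
        raise NotImplementedError

    @property
    def openai_client(self) -> AsyncOpenAI:
        # Build an OpenAI-compatible client from the subclass-provided hooks.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())

    async def openai_chat_completion(self, **params):
        # Forward the request to the server's OpenAI-compatible chat endpoint.
        return await self.openai_client.chat.completions.create(**params)


class OllamaAdapterSketch(OpenAICompatMixinSketch):
    """Mirrors the hooks the diff adds to OllamaInferenceAdapter."""

    def __init__(self, url: str) -> None:
        self.url = url

    def get_api_key(self) -> str:
        return "NO_KEY"  # Ollama ignores the key, but the OpenAI client requires one

    def get_base_url(self) -> str:
        return self.url.rstrip("/") + "/v1"  # Ollama exposes OpenAI-compat routes under /v1

With this split, the openai-compat endpoint behavior lives once in the mixin, and each provider only describes how to reach its server.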

File tree

4 files changed: +1100 -133 lines changed


llama_stack/providers/remote/inference/ollama/ollama.py

Lines changed: 24 additions & 133 deletions
@@ -7,12 +7,10 @@

 import asyncio
 import base64
-import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any

-from ollama import AsyncClient  # type: ignore[attr-defined]
-from openai import AsyncOpenAI
+from ollama import AsyncClient as AsyncOllamaClient

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -37,9 +35,6 @@
     Message,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ResponseFormat,
@@ -64,15 +59,14 @@
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
-    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
@@ -89,6 +83,7 @@


 class OllamaInferenceAdapter(
+    OpenAIMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
 ):
@@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.config = config
-        self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
-        self._openai_client = None
+        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

     @property
-    def client(self) -> AsyncClient:
+    def ollama_client(self) -> AsyncOllamaClient:
         # ollama client attaches itself to the current event loop (sadly?)
         loop = asyncio.get_running_loop()
         if loop not in self._clients:
-            self._clients[loop] = AsyncClient(host=self.config.url)
+            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
         return self._clients[loop]

-    @property
-    def openai_client(self) -> AsyncOpenAI:
-        if self._openai_client is None:
-            url = self.config.url.rstrip("/")
-            self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
-        return self._openai_client
+    def get_api_key(self):
+        return "NO_KEY"
+
+    def get_base_url(self):
+        return self.config.url.rstrip("/") + "/v1"

     async def initialize(self) -> None:
         logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@@ -129,7 +122,7 @@ async def should_refresh_models(self) -> bool:

     async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        response = await self.client.list()
+        response = await self.ollama_client.list()

         # always add the two embedding models which can be pulled on demand
         models = [
@@ -189,7 +182,7 @@ async def health(self) -> HealthResponse:
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            await self.client.ps()
+            await self.ollama_client.ps()
             return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -238,7 +231,7 @@ async def _stream_completion(
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
@@ -254,7 +247,7 @@ async def _generate_and_convert_to_openai_compat():

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.generate(**params)
+        r = await self.ollama_client.generate(**params)

         choice = OpenAICompatCompletionChoice(
             finish_reason=r["done_reason"] if r["done"] else None,
@@ -346,9 +339,9 @@ async def _get_params(self, request: ChatCompletionRequest | CompletionRequest)
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self.client.chat(**params)
+            r = await self.ollama_client.chat(**params)
         else:
-            r = await self.client.generate(**params)
+            r = await self.ollama_client.generate(**params)

         if "message" in r:
             choice = OpenAICompatCompletionChoice(
@@ -372,9 +365,9 @@ async def _stream_chat_completion(

         async def _generate_and_convert_to_openai_compat():
             if "messages" in params:
-                s = await self.client.chat(**params)
+                s = await self.ollama_client.chat(**params)
             else:
-                s = await self.client.generate(**params)
+                s = await self.ollama_client.generate(**params)
             async for chunk in s:
                 if "message" in chunk:
                     choice = OpenAICompatCompletionChoice(
@@ -407,7 +400,7 @@ async def embeddings(
         assert all(not content_has_media(content) for content in contents), (
             "Ollama does not support media for embeddings"
         )
-        response = await self.client.embed(
+        response = await self.ollama_client.embed(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
         )
@@ -422,14 +415,14 @@ async def register_model(self, model: Model) -> Model:
             pass  # Ignore statically unknown model, will check live listing

         if model.model_type == ModelType.embedding:
-            response = await self.client.list()
+            response = await self.ollama_client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
-                await self.client.pull(model.provider_resource_id)
+                await self.ollama_client.pull(model.provider_resource_id)

         # we use list() here instead of ps() -
         # - ps() only lists running models, not available models
         # - models not currently running are run by the ollama server as needed
-        response = await self.client.list()
+        response = await self.ollama_client.list()
         available_models = [m.model for m in response.models]

         provider_resource_id = model.provider_resource_id
@@ -448,90 +441,6 @@ async def register_model(self, model: Model) -> Model:

         return model

-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        model_obj = await self._get_model(model)
-        if model_obj.provider_resource_id is None:
-            raise ValueError(f"Model {model} has no provider_resource_id set")
-
-        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        params = prepare_openai_embeddings_params(
-            model=model_obj.provider_resource_id,
-            input=input,
-            encoding_format=encoding_format,
-            dimensions=dimensions,
-            user=user,
-        )
-
-        response = await self.openai_client.embeddings.create(**params)
-        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
-
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-        # TODO: Investigate why model_obj.identifier is used instead of response.model
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=model_obj.identifier,
-            usage=usage,
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        if not isinstance(prompt, str):
-            raise ValueError("Ollama does not support non-string prompts for completion")
-
-        model_obj = await self._get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            suffix=suffix,
-        )
-        return await self.openai_client.completions.create(**params)  # type: ignore
-
     async def openai_chat_completion(
         self,
         model: str,
@@ -599,25 +508,7 @@ async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
             top_p=top_p,
             user=user,
         )
-        response = await self.openai_client.chat.completions.create(**params)
-        return await self._adjust_ollama_chat_completion_response_ids(response)
-
-    async def _adjust_ollama_chat_completion_response_ids(
-        self,
-        response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        id = f"chatcmpl-{uuid.uuid4()}"
-        if isinstance(response, AsyncIterator):
-
-            async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
-                async for chunk in response:
-                    chunk.id = id
-                    yield chunk
-
-            return stream_with_chunk_ids()
-        else:
-            response.id = id
-            return response
+        return await OpenAIMixin.openai_chat_completion(self, **params)


 async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
