from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Literal, Union

from httpx import AsyncClient as AsyncHTTPClient

from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    cached_async_http_client,
)

try:
    from openai import AsyncOpenAI
except ImportError as e:
    raise ImportError(
        'Please install `openai` to use the OpenAI model, '
        "you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
    ) from e


from .openai import OpenAIModel

CommonOllamaModelNames = Literal[
    'codellama',
    'gemma',
    'gemma2',
    'llama3',
    'llama3.1',
    'llama3.2',
    'llama3.2-vision',
    'llama3.3',
    'mistral',
    'mistral-nemo',
    'mixtral',
    'phi3',
    'qwq',
    'qwen',
    'qwen2',
    'qwen2.5',
    'starcoder2',
]
| 45 | +"""This contains just the most common ollama models. |
| 46 | +
|
| 47 | +For a full list see [ollama.com/library](https://ollama.com/library). |
| 48 | +""" |
| 49 | +OllamaModelName = Union[CommonOllamaModelNames, str] |
| 50 | +"""Possible ollama models. |
| 51 | +
|
| 52 | +Since Ollama supports hundreds of models, we explicitly list the most models but |
| 53 | +allow any name in the type hints. |
| 54 | +""" |


@dataclass(init=False)
class OllamaModel(Model):
    """A model that implements Ollama using the OpenAI API.

    Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.

    Apart from `__init__`, all methods are private or match those of the base class.
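
    A minimal usage sketch (this assumes a local Ollama server on the default port and
    that the model was pulled beforehand with `ollama pull llama3.2`):

        from pydantic_ai import Agent
        from pydantic_ai.models.ollama import OllamaModel

        agent = Agent(OllamaModel('llama3.2'))
        result = agent.run_sync('Where were the 2012 Summer Olympics held?')
        print(result.data)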
    """

    model_name: OllamaModelName
    openai_model: OpenAIModel

    def __init__(
        self,
        model_name: OllamaModelName,
        *,
        base_url: str | None = 'http://localhost:11434/v1/',
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize an Ollama model.

        Ollama has built-in compatibility with the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
        [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.

        Args:
            model_name: The name of the Ollama model to use. The list of available models is [here](https://ollama.com/library).
                You must first download the model (`ollama pull <MODEL-NAME>`) in order to use it.
            base_url: The base URL for Ollama requests. The default value is the Ollama default.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `base_url` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
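
        As a sketch, to point at an Ollama server running on another machine (the host
        below is a placeholder, not a real endpoint):

            model = OllamaModel(
                'qwen2.5',
                base_url='http://my-ollama-host:11434/v1/',
            )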
        """
        self.model_name = model_name
        if openai_client is not None:
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
        else:
            # an API key is not required for Ollama, but the OpenAI client requires some value
            oai_client = AsyncOpenAI(
                base_url=base_url, api_key='ollama', http_client=http_client or cached_async_http_client()
            )
            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        return await self.openai_model.agent_model(
            function_tools=function_tools,
            allow_text_result=allow_text_result,
            result_tools=result_tools,
        )

    def name(self) -> str:
        return f'ollama:{self.model_name}'
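

# The `ollama:<model>` string returned by `name()` mirrors the shorthand that `Agent`
# also accepts directly (a sketch, assuming pydantic-ai's string-based model inference):
#
#     from pydantic_ai import Agent
#
#     agent = Agent('ollama:llama3.2')  # equivalent to Agent(OllamaModel('llama3.2'))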