Skip to content

Commit bdf41b6

Browse files
cal859CalCal McAuliffesamuelcolvinakgerber
authored
Ollama support (#162)
Co-authored-by: Cal <[email protected]> Co-authored-by: Cal McAuliffe <[email protected]> Co-authored-by: Samuel Colvin <[email protected]> Co-authored-by: Alan Gerber <[email protected]> Co-authored-by: Alex Hall <[email protected]>
1 parent bf5c295 commit bdf41b6

File tree

7 files changed

+259
-6
lines changed

7 files changed

+259
-6
lines changed

docs/api/models/ollama.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# `pydantic_ai.models.ollama`
2+
3+
## Setup
4+
5+
For details on how to set up authentication with this model, see [model configuration for Ollama](../../install.md#ollama).
6+
7+
## Example usage
8+
9+
With `ollama` installed, you can run the server with the model you want to use:
10+
11+
```bash title="terminal-run-ollama"
12+
ollama run llama3.2
13+
```
14+
(this will pull the `llama3.2` model if you don't already have it downloaded)
15+
16+
Then run your code, here's a minimal example:
17+
18+
```py title="ollama_example.py"
19+
from pydantic import BaseModel
20+
21+
from pydantic_ai import Agent
22+
23+
24+
class CityLocation(BaseModel):
25+
city: str
26+
country: str
27+
28+
29+
agent = Agent('ollama:llama3.2', result_type=CityLocation)
30+
31+
result = agent.run_sync('Where the olympics held in 2012?')
32+
print(result.data)
33+
#> city='London' country='United Kingdom'
34+
print(result.cost())
35+
#> Cost(request_tokens=56, response_tokens=8, total_tokens=64, details=None)
36+
```
37+
38+
See [`OllamaModel`][pydantic_ai.models.ollama.OllamaModel] for more information
39+
40+
::: pydantic_ai.models.ollama

docs/install.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,9 @@ model = GroqModel('llama-3.1-70b-versatile', api_key='your-api-key')
323323
agent = Agent(model)
324324
...
325325
```
326+
327+
### Ollama
328+
329+
To use [Ollama](https://ollama.com/), you must first download the Ollama client, and then download a model.
330+
331+
You must also ensure the Ollama server is running when trying to make requests to it. For more information, please see the [Ollama documentation](https://github.com/ollama/ollama/tree/main/docs)

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ nav:
3838
- api/exceptions.md
3939
- api/models/base.md
4040
- api/models/openai.md
41+
- api/models/ollama.md
4142
- api/models/gemini.md
4243
- api/models/vertexai.md
4344
- api/models/groq.md

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,23 @@
4949
'gemini-1.5-pro',
5050
'vertexai:gemini-1.5-flash',
5151
'vertexai:gemini-1.5-pro',
52+
'ollama:codellama',
53+
'ollama:gemma',
54+
'ollama:gemma2',
55+
'ollama:llama3',
56+
'ollama:llama3.1',
57+
'ollama:llama3.2',
58+
'ollama:llama3.2-vision',
59+
'ollama:llama3.3',
60+
'ollama:mistral',
61+
'ollama:mistral-nemo',
62+
'ollama:mixtral',
63+
'ollama:phi3',
64+
'ollama:qwq',
65+
'ollama:qwen',
66+
'ollama:qwen2',
67+
'ollama:qwen2.5',
68+
'ollama:starcoder2',
5269
'test',
5370
]
5471
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -239,7 +256,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
239256
elif model.startswith('openai:'):
240257
from .openai import OpenAIModel
241258

242-
return OpenAIModel(model[7:]) # pyright: ignore[reportArgumentType]
259+
return OpenAIModel(model[7:])
243260
elif model.startswith('gemini'):
244261
from .gemini import GeminiModel
245262

@@ -253,6 +270,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
253270
from .vertexai import VertexAIModel
254271

255272
return VertexAIModel(model[9:]) # pyright: ignore[reportArgumentType]
273+
elif model.startswith('ollama:'):
274+
from .ollama import OllamaModel
275+
276+
return OllamaModel(model[7:])
256277
else:
257278
raise UserError(f'Unknown model: {model}')
258279

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from __future__ import annotations as _annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Literal, Union
5+
6+
from httpx import AsyncClient as AsyncHTTPClient
7+
8+
from ..tools import ToolDefinition
9+
from . import (
10+
AgentModel,
11+
Model,
12+
cached_async_http_client,
13+
)
14+
15+
try:
16+
from openai import AsyncOpenAI
17+
except ImportError as e:
18+
raise ImportError(
19+
'Please install `openai` to use the OpenAI model, '
20+
"you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
21+
) from e
22+
23+
24+
from .openai import OpenAIModel
25+
26+
CommonOllamaModelNames = Literal[
27+
'codellama',
28+
'gemma',
29+
'gemma2',
30+
'llama3',
31+
'llama3.1',
32+
'llama3.2',
33+
'llama3.2-vision',
34+
'llama3.3',
35+
'mistral',
36+
'mistral-nemo',
37+
'mixtral',
38+
'phi3',
39+
'qwq',
40+
'qwen',
41+
'qwen2',
42+
'qwen2.5',
43+
'starcoder2',
44+
]
45+
"""This contains just the most common ollama models.
46+
47+
For a full list see [ollama.com/library](https://ollama.com/library).
48+
"""
49+
OllamaModelName = Union[CommonOllamaModelNames, str]
50+
"""Possible ollama models.
51+
52+
Since Ollama supports hundreds of models, we explicitly list the most models but
53+
allow any name in the type hints.
54+
"""
55+
56+
57+
@dataclass(init=False)
58+
class OllamaModel(Model):
59+
"""A model that implements Ollama using the OpenAI API.
60+
61+
Internally, this uses the [OpenAI Python client](https://github.com/openai/openai-python) to interact with the Ollama server.
62+
63+
Apart from `__init__`, all methods are private or match those of the base class.
64+
"""
65+
66+
model_name: OllamaModelName
67+
openai_model: OpenAIModel
68+
69+
def __init__(
70+
self,
71+
model_name: OllamaModelName,
72+
*,
73+
base_url: str | None = 'http://localhost:11434/v1/',
74+
openai_client: AsyncOpenAI | None = None,
75+
http_client: AsyncHTTPClient | None = None,
76+
):
77+
"""Initialize an Ollama model.
78+
79+
Ollama has built-in compatability for the OpenAI chat completions API ([source](https://ollama.com/blog/openai-compatibility)), so we reuse the
80+
[`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel] here.
81+
82+
Args:
83+
model_name: The name of the Ollama model to use. List of models available [here](https://ollama.com/library)
84+
You must first download the model (`ollama pull <MODEL-NAME>`) in order to use the model
85+
base_url: The base url for the ollama requests. The default value is the ollama default
86+
openai_client: An existing
87+
[`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
88+
client to use, if provided, `base_url` and `http_client` must be `None`.
89+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
90+
"""
91+
self.model_name = model_name
92+
if openai_client is not None:
93+
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
94+
self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client, http_client=http_client)
95+
elif http_client is not None:
96+
# API key is not required for ollama but a value is required to create the client
97+
oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client)
98+
self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client, http_client=http_client)
99+
else:
100+
# API key is not required for ollama but a value is required to create the client
101+
oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=cached_async_http_client())
102+
self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client, http_client=http_client)
103+
104+
async def agent_model(
105+
self,
106+
*,
107+
function_tools: list[ToolDefinition],
108+
allow_text_result: bool,
109+
result_tools: list[ToolDefinition],
110+
) -> AgentModel:
111+
return await self.openai_model.agent_model(
112+
function_tools=function_tools,
113+
allow_text_result=allow_text_result,
114+
result_tools=result_tools,
115+
)
116+
117+
def name(self) -> str:
118+
return f'ollama:{self.model_name}'

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from contextlib import asynccontextmanager
55
from dataclasses import dataclass, field
66
from datetime import datetime, timezone
7-
from typing import Literal, overload
7+
from typing import Literal, Union, overload
88

99
from httpx import AsyncClient as AsyncHTTPClient
1010
from typing_extensions import assert_never
@@ -43,6 +43,12 @@
4343
"you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
4444
) from _import_error
4545

46+
OpenAIModelName = Union[ChatModel, str]
47+
"""
48+
Using this more broad type for the model name instead of the ChatModel definition
49+
allows this model to be used more easily with other model types (ie, Ollama)
50+
"""
51+
4652

4753
@dataclass(init=False)
4854
class OpenAIModel(Model):
@@ -53,12 +59,12 @@ class OpenAIModel(Model):
5359
Apart from `__init__`, all methods are private or match those of the base class.
5460
"""
5561

56-
model_name: ChatModel
62+
model_name: OpenAIModelName
5763
client: AsyncOpenAI = field(repr=False)
5864

5965
def __init__(
6066
self,
61-
model_name: ChatModel,
67+
model_name: OpenAIModelName,
6268
*,
6369
api_key: str | None = None,
6470
openai_client: AsyncOpenAI | None = None,
@@ -77,7 +83,7 @@ def __init__(
7783
client to use, if provided, `api_key` and `http_client` must be `None`.
7884
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
7985
"""
80-
self.model_name: ChatModel = model_name
86+
self.model_name: OpenAIModelName = model_name
8187
if openai_client is not None:
8288
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
8389
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
@@ -125,7 +131,7 @@ class OpenAIAgentModel(AgentModel):
125131
"""Implementation of `AgentModel` for OpenAI models."""
126132

127133
client: AsyncOpenAI
128-
model_name: ChatModel
134+
model_name: OpenAIModelName
129135
allow_text_result: bool
130136
tools: list[chat.ChatCompletionToolParam]
131137

tests/models/test_ollama.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations as _annotations
2+
3+
from datetime import datetime, timezone
4+
5+
import pytest
6+
from inline_snapshot import snapshot
7+
8+
from pydantic_ai import Agent
9+
from pydantic_ai.messages import (
10+
ModelTextResponse,
11+
UserPrompt,
12+
)
13+
from pydantic_ai.result import Cost
14+
15+
from ..conftest import IsNow, try_import
16+
17+
with try_import() as imports_successful:
18+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
19+
20+
from pydantic_ai.models.ollama import OllamaModel
21+
22+
from .test_openai import MockOpenAI, completion_message
23+
24+
pytestmark = [
25+
pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
26+
pytest.mark.anyio,
27+
]
28+
29+
30+
def test_init():
31+
m = OllamaModel('llama3.2', base_url='foobar/')
32+
assert m.openai_model.client.api_key == 'ollama'
33+
assert m.openai_model.client.base_url == 'foobar/'
34+
assert m.name() == 'ollama:llama3.2'
35+
36+
37+
async def test_request_simple_success(allow_model_requests: None):
38+
c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
39+
mock_client = MockOpenAI.create_mock(c)
40+
print('here')
41+
m = OllamaModel('llama3.2', openai_client=mock_client, base_url=None)
42+
agent = Agent(m)
43+
44+
result = await agent.run('hello')
45+
assert result.data == 'world'
46+
assert result.cost() == snapshot(Cost())
47+
48+
# reset the index so we get the same response again
49+
mock_client.index = 0 # type: ignore
50+
51+
result = await agent.run('hello', message_history=result.new_messages())
52+
assert result.data == 'world'
53+
assert result.cost() == snapshot(Cost())
54+
assert result.all_messages() == snapshot(
55+
[
56+
UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
57+
ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
58+
UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
59+
ModelTextResponse(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
60+
]
61+
)

0 commit comments

Comments
 (0)