
Commit 864b926

add tests for Ollama (#182)
1 parent bdf41b6 commit 864b926

3 files changed: +22 −9 lines changed


.github/workflows/ci.yml

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,10 @@ jobs:
         with:
           enable-cache: true
 
+      - uses: pydantic/ollama-action@main
+        with:
+          model: qwen:0.5b
+
       - run: >
           uv run
           --python 3.12
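
The new step presumably spins up a local Ollama server and pulls the qwen:0.5b model so the live tests below have an endpoint to hit. As a rough illustration (not part of the commit), a local reachability check might look like the following; the default port 11434 and the /api/tags endpoint are assumptions about a stock Ollama install.

import httpx

def ollama_is_up(base_url: str = 'http://localhost:11434') -> bool:
    # GET /api/tags lists locally available models; a 200 means the server is running.
    try:
        return httpx.get(f'{base_url}/api/tags', timeout=2.0).status_code == 200
    except httpx.HTTPError:
        return False

if __name__ == '__main__':
    print('Ollama reachable:', ollama_is_up())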

pydantic_ai_slim/pydantic_ai/models/ollama.py

Lines changed: 5 additions & 7 deletions
@@ -91,15 +91,13 @@ def __init__(
         self.model_name = model_name
         if openai_client is not None:
             assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
-            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client, http_client=http_client)
-        elif http_client is not None:
-            # API key is not required for ollama but a value is required to create the client
-            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client)
-            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client, http_client=http_client)
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=openai_client)
         else:
             # API key is not required for ollama but a value is required to create the client
-            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=cached_async_http_client())
-            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client, http_client=http_client)
+            http_client_ = http_client or cached_async_http_client()
+            oai_client = AsyncOpenAI(base_url=base_url, api_key='ollama', http_client=http_client_)
+            self.openai_model = OpenAIModel(model_name=model_name, openai_client=oai_client)
 
     async def agent_model(
         self,
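
The refactor collapses the three constructor branches into two: passing `openai_client` now explicitly excludes `http_client`, and otherwise any supplied `http_client` (or the cached default) is forwarded to `AsyncOpenAI`. A minimal usage sketch of both paths, assuming `base_url` defaults to a local Ollama endpoint when omitted:

import httpx

from pydantic_ai.models.ollama import OllamaModel

# Neither client passed: the `else` branch builds an AsyncOpenAI client on top of
# cached_async_http_client(); base_url is assumed to default to the local Ollama server.
model = OllamaModel('qwen:0.5b')

# Explicit httpx client: after this change it is actually handed to AsyncOpenAI
# instead of being dropped.
model_with_client = OllamaModel('qwen:0.5b', http_client=httpx.AsyncClient())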

tests/test_live.py

Lines changed: 13 additions & 2 deletions
@@ -48,11 +48,18 @@ def groq(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
     return GroqModel('llama-3.1-70b-versatile', http_client=http_client)
 
 
+def ollama(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
+    from pydantic_ai.models.ollama import OllamaModel
+
+    return OllamaModel('qwen:0.5b', http_client=http_client)
+
+
 params = [
     pytest.param(openai, id='openai'),
     pytest.param(gemini, id='gemini'),
     pytest.param(vertexai, id='vertexai'),
     pytest.param(groq, id='groq'),
+    pytest.param(ollama, id='ollama'),
 ]
 GetModel = Callable[[httpx.AsyncClient, Path], Model]
 
@@ -83,14 +90,18 @@ async def test_stream(http_client: httpx.AsyncClient, tmp_path: Path, get_model:
     assert 'paris' in data.lower()
     print('Stream cost:', result.cost())
     cost = result.cost()
-    assert cost.total_tokens is not None and cost.total_tokens > 0
+    if get_model.__name__ != 'ollama':
+        assert cost.total_tokens is not None and cost.total_tokens > 0
 
 
 class MyModel(BaseModel):
     city: str
 
 
-@pytest.mark.parametrize('get_model', params)
+structured_params = [p for p in params if p.id != 'ollama']
+
+
+@pytest.mark.parametrize('get_model', structured_params)
 async def test_structured(http_client: httpx.AsyncClient, tmp_path: Path, get_model: GetModel):
     agent = Agent(get_model(http_client, tmp_path), result_type=MyModel)
     result = await agent.run('What is the capital of the UK?')
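
The ollama factory joins the shared params list, with two carve-outs, presumably because qwen:0.5b does not report token counts when streaming and is too small to produce reliable structured output: the token assertion in test_stream is skipped for it, and structured_params filters it out of test_structured. A hedged sketch of how one of these factories is exercised outside pytest (the prompt string is illustrative only):

import asyncio

import httpx

from pydantic_ai import Agent
from pydantic_ai.models.ollama import OllamaModel


async def main() -> None:
    async with httpx.AsyncClient() as http_client:
        # Mirrors what the parametrized tests do with get_model(http_client, tmp_path).
        agent = Agent(OllamaModel('qwen:0.5b', http_client=http_client))
        result = await agent.run('What is the capital of France?')
        print(result.data, result.cost())


if __name__ == '__main__':
    asyncio.run(main())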
