Skip to content

Commit d4a52ef

Browse files
authored
test against real models (#49)
1 parent 6dc3e8d commit d4a52ef

File tree

4 files changed

+114
-5
lines changed

4 files changed

+114
-5
lines changed

.github/workflows/ci.yml

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ jobs:
5656

5757
- run: tree site
5858

59+
test-live:
60+
runs-on: ubuntu-latest
61+
steps:
62+
- uses: actions/checkout@v4
63+
64+
- uses: astral-sh/setup-uv@v3
65+
with:
66+
enable-cache: true
67+
68+
- run: uv run --python 3.12 --frozen pytest tests/test_live.py -v --durations=100
69+
if: github.repository_owner == 'pydantic'
70+
env:
71+
PYDANTIC_AI_LIVE_TEST_DANGEROUS: 'CHARGE-ME!'
72+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
73+
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
74+
5975
test:
6076
name: test on ${{ matrix.python-version }}
6177
runs-on: ubuntu-latest
@@ -64,7 +80,7 @@ jobs:
6480
matrix:
6581
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
6682
env:
67-
PYTHON: ${{ matrix.python-version }}
83+
UV_PYTHON: ${{ matrix.python-version }}
6884
steps:
6985
- uses: actions/checkout@v4
7086

@@ -73,15 +89,15 @@ jobs:
7389
enable-cache: true
7490

7591
- run: mkdir coverage
76-
- run: uv run --frozen --python ${{ matrix.python-version }} coverage run -m pytest
92+
- run: uv run --frozen coverage run -m pytest
7793
env:
7894
COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}
7995

80-
- run: uv run --frozen --all-extras --python ${{ matrix.python-version }} coverage run -m pytest
96+
- run: uv run --frozen --all-extras coverage run -m pytest
8197
env:
8298
COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-all-extras
8399

84-
- run: uv run --frozen --all-extras --python ${{ matrix.python-version }} python tests/import_examples.py
100+
- run: uv run --frozen --all-extras python tests/import_examples.py
85101

86102
- name: store coverage files
87103
uses: actions/upload-artifact@v4
@@ -117,7 +133,7 @@ jobs:
117133
# https://github.com/marketplace/actions/alls-green#why used for branch protection checks
118134
check:
119135
if: always()
120-
needs: [lint, docs, test, coverage]
136+
needs: [lint, docs, test-live, test, coverage]
121137
runs-on: ubuntu-latest
122138

123139
steps:

pydantic_ai/models/gemini.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncI
187187

188188
async with self.http_client.stream('POST', url, content=request_json, headers=headers) as r:
189189
if r.status_code != 200:
190+
await r.aread()
190191
raise exceptions.UnexpectedModelBehaviour(f'Unexpected response from gemini {r.status_code}', r.text)
191192
yield r
192193

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ filterwarnings = [
146146
[tool.coverage.run]
147147
# required to avoid warnings about files created by create_module fixture
148148
include = ["pydantic_ai/**/*.py", "tests/**/*.py"]
149+
omit = ["tests/test_live.py"]
149150
branch = true
150151

151152
# https://coverage.readthedocs.io/en/latest/config.html#report

tests/test_live.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Tests of pydantic-ai actually connecting to OpenAI and Gemini models.
2+
3+
WARNING: running these tests will consume your OpenAI and Gemini credits.
4+
"""
5+
6+
import os
7+
8+
import httpx
9+
import pytest
10+
from pydantic import BaseModel
11+
12+
from pydantic_ai import Agent
13+
from pydantic_ai.models.gemini import GeminiModel
14+
from pydantic_ai.models.openai import OpenAIModel
15+
16+
# Applied to every test in this module: skip unless the caller has explicitly
# opted in to spending real API credits, and run all tests under anyio.
pytestmark = [
    pytest.mark.skipif(
        os.getenv('PYDANTIC_AI_LIVE_TEST_DANGEROUS') != 'CHARGE-ME!',
        reason='live tests disabled',
    ),
    pytest.mark.anyio,
]
20+
21+
22+
@pytest.fixture
async def http_client():
    """Yield a shared async HTTP client (30s timeout), closed when the test ends."""
    async with httpx.AsyncClient(timeout=30) as client:
        yield client
26+
27+
28+
async def test_openai(http_client: httpx.AsyncClient):
    """Live check: plain-text completion against OpenAI gpt-3.5-turbo.

    Consumes real OpenAI credits; gated by the module-level skipif guard.
    """
    agent = Agent(OpenAIModel('gpt-3.5-turbo', http_client=http_client))
    result = await agent.run('What is the capital of France?')
    print('OpenAI response:', result.data)
    assert 'paris' in result.data.lower()
    # Compute the cost once and reuse it (the original called cost() twice).
    cost = result.cost()
    print('OpenAI cost:', cost)
    assert cost.total_tokens is not None and cost.total_tokens > 0
36+
37+
38+
async def test_openai_stream(http_client: httpx.AsyncClient):
    """Live check: streamed completion against OpenAI gpt-3.5-turbo.

    Consumes real OpenAI credits; gated by the module-level skipif guard.
    """
    agent = Agent(OpenAIModel('gpt-3.5-turbo', http_client=http_client))
    async with agent.run_stream('What is the capital of France?') as result:
        data = await result.get_data()
    print('OpenAI stream response:', data)
    assert 'paris' in data.lower()
    # Compute the cost once and reuse it (the original called cost() twice).
    cost = result.cost()
    print('OpenAI stream cost:', cost)
    assert cost.total_tokens is not None and cost.total_tokens > 0
47+
48+
49+
class MyModel(BaseModel):
    """Structured-result schema for the *_structured tests: a single city name."""

    city: str
51+
52+
53+
async def test_openai_structured(http_client: httpx.AsyncClient):
    """Live check: structured (MyModel) result against OpenAI gpt-4o-mini.

    Consumes real OpenAI credits; gated by the module-level skipif guard.
    """
    agent = Agent(OpenAIModel('gpt-4o-mini', http_client=http_client), result_type=MyModel)
    result = await agent.run('What is the capital of the UK?')
    print('OpenAI structured response:', result.data)
    assert result.data.city.lower() == 'london'
    # Compute the cost once and reuse it (the original called cost() twice).
    cost = result.cost()
    print('OpenAI structured cost:', cost)
    assert cost.total_tokens is not None and cost.total_tokens > 0
61+
62+
63+
async def test_gemini(http_client: httpx.AsyncClient):
    """Live check: plain-text completion against Gemini 1.5 Flash.

    Consumes real Gemini credits; gated by the module-level skipif guard.
    """
    agent = Agent(GeminiModel('gemini-1.5-flash', http_client=http_client))
    result = await agent.run('What is the capital of France?')
    print('Gemini response:', result.data)
    assert 'paris' in result.data.lower()
    # Compute the cost once and reuse it (the original called cost() twice).
    cost = result.cost()
    print('Gemini cost:', cost)
    assert cost.total_tokens is not None and cost.total_tokens > 0
71+
72+
73+
async def test_gemini_stream(http_client: httpx.AsyncClient):
    """Live check: streamed completion against Gemini 1.5 Pro.

    Consumes real Gemini credits; gated by the module-level skipif guard.
    """
    agent = Agent(GeminiModel('gemini-1.5-pro', http_client=http_client))
    async with agent.run_stream('What is the capital of France?') as result:
        data = await result.get_data()
    print('Gemini stream response:', data)
    assert 'paris' in data.lower()
    # Compute the cost once and reuse it (the original called cost() twice).
    cost = result.cost()
    print('Gemini stream cost:', cost)
    assert cost.total_tokens is not None and cost.total_tokens > 0
82+
83+
84+
async def test_gemini_structured(http_client: httpx.AsyncClient):
    """Live check: structured (MyModel) result against Gemini 1.5 Pro.

    Consumes real Gemini credits; gated by the module-level skipif guard.
    """
    agent = Agent(GeminiModel('gemini-1.5-pro', http_client=http_client), result_type=MyModel)
    result = await agent.run('What is the capital of the UK?')
    print('Gemini structured response:', result.data)
    assert result.data.city.lower() == 'london'
    # Compute the cost once and reuse it (the original called cost() twice).
    cost = result.cost()
    print('Gemini structured cost:', cost)
    assert cost.total_tokens is not None and cost.total_tokens > 0

0 commit comments

Comments
 (0)