Skip to content

Commit 10b3c91

Browse files
Fix handling of OpenAI system prompts in order to support o1 (#740)
Co-authored-by: David Montague <[email protected]>
1 parent 6f58f57 commit 10b3c91

File tree

4 files changed

+100
-41
lines changed

4 files changed

+100
-41
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
allows this model to be used more easily with other model types (ie, Ollama)
5252
"""
5353

54+
OpenAISystemPromptRole = Literal['system', 'developer']
55+
5456

5557
@dataclass(init=False)
5658
class OpenAIModel(Model):
@@ -63,6 +65,7 @@ class OpenAIModel(Model):
6365

6466
model_name: OpenAIModelName
6567
client: AsyncOpenAI = field(repr=False)
68+
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
6669

6770
def __init__(
6871
self,
@@ -72,6 +75,7 @@ def __init__(
7275
api_key: str | None = None,
7376
openai_client: AsyncOpenAI | None = None,
7477
http_client: AsyncHTTPClient | None = None,
78+
system_prompt_role: OpenAISystemPromptRole | None = None,
7579
):
7680
"""Initialize an OpenAI model.
7781
@@ -87,6 +91,8 @@ def __init__(
8791
[`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
8892
client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
8993
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
94+
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
95+
In the future, this may be inferred from the model name.
9096
"""
9197
self.model_name: OpenAIModelName = model_name
9298
if openai_client is not None:
@@ -98,6 +104,7 @@ def __init__(
98104
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
99105
else:
100106
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
107+
self.system_prompt_role = system_prompt_role
101108

102109
async def agent_model(
103110
self,
@@ -115,6 +122,7 @@ async def agent_model(
115122
self.model_name,
116123
allow_text_result,
117124
tools,
125+
self.system_prompt_role,
118126
)
119127

120128
def name(self) -> str:
@@ -140,6 +148,7 @@ class OpenAIAgentModel(AgentModel):
140148
model_name: OpenAIModelName
141149
allow_text_result: bool
142150
tools: list[chat.ChatCompletionToolParam]
151+
system_prompt_role: OpenAISystemPromptRole | None
143152

144153
async def request(
145154
self, messages: list[ModelMessage], model_settings: ModelSettings | None
@@ -222,11 +231,10 @@ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionC
222231
_timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
223232
)
224233

225-
@classmethod
226-
def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
234+
def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
227235
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
228236
if isinstance(message, ModelRequest):
229-
yield from cls._map_user_message(message)
237+
yield from self._map_user_message(message)
230238
elif isinstance(message, ModelResponse):
231239
texts: list[str] = []
232240
tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -248,11 +256,13 @@ def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMess
248256
else:
249257
assert_never(message)
250258

251-
@classmethod
252-
def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
259+
def _map_user_message(self, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
253260
for part in message.parts:
254261
if isinstance(part, SystemPromptPart):
255-
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
262+
if self.system_prompt_role == 'developer':
263+
yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
264+
else:
265+
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
256266
elif isinstance(part, UserPromptPart):
257267
yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
258268
elif isinstance(part, ToolReturnPart):

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ dependencies = [
4545
# WARNING if you add optional groups, please update docs/install.md
4646
logfire = ["logfire>=2.3"]
4747
graph = ["pydantic-graph==0.0.19"]
48-
openai = ["openai>=1.54.3"]
48+
openai = ["openai>=1.59.0"]
4949
cohere = ["cohere>=5.13.11"]
5050
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
5151
anthropic = ["anthropic>=0.40.0"]

tests/models/test_openai.py

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
from collections.abc import Sequence
5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
66
from datetime import datetime, timezone
77
from functools import cached_property
88
from typing import Any, Literal, cast
@@ -28,7 +28,7 @@
2828
from .mock_async_stream import MockAsyncStream
2929

3030
with try_import() as imports_successful:
31-
from openai import AsyncOpenAI
31+
from openai import NOT_GIVEN, AsyncOpenAI
3232
from openai.types import chat
3333
from openai.types.chat.chat_completion import Choice
3434
from openai.types.chat.chat_completion_chunk import (
@@ -41,7 +41,7 @@
4141
from openai.types.chat.chat_completion_message_tool_call import Function
4242
from openai.types.completion_usage import CompletionUsage, PromptTokensDetails
4343

44-
from pydantic_ai.models.openai import OpenAIModel
44+
from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
4545

4646
pytestmark = [
4747
pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
@@ -50,25 +50,26 @@
5050

5151

5252
def test_init():
53-
m = OpenAIModel('gpt-4', api_key='foobar')
53+
m = OpenAIModel('gpt-4o', api_key='foobar')
5454
assert str(m.client.base_url) == 'https://api.openai.com/v1/'
5555
assert m.client.api_key == 'foobar'
56-
assert m.name() == 'openai:gpt-4'
56+
assert m.name() == 'openai:gpt-4o'
5757

5858

5959
def test_init_with_base_url():
60-
m = OpenAIModel('gpt-4', base_url='https://example.com/v1', api_key='foobar')
60+
m = OpenAIModel('gpt-4o', base_url='https://example.com/v1', api_key='foobar')
6161
assert str(m.client.base_url) == 'https://example.com/v1/'
6262
assert m.client.api_key == 'foobar'
63-
assert m.name() == 'openai:gpt-4'
63+
assert m.name() == 'openai:gpt-4o'
6464
m.name()
6565

6666

6767
@dataclass
6868
class MockOpenAI:
6969
completions: chat.ChatCompletion | list[chat.ChatCompletion] | None = None
7070
stream: list[chat.ChatCompletionChunk] | list[list[chat.ChatCompletionChunk]] | None = None
71-
index = 0
71+
index: int = 0
72+
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
7273

7374
@cached_property
7475
def chat(self) -> Any:
@@ -86,8 +87,10 @@ def create_mock_stream(
8687
return cast(AsyncOpenAI, cls(stream=list(stream))) # pyright: ignore[reportArgumentType]
8788

8889
async def chat_completions_create( # pragma: no cover
89-
self, *_args: Any, stream: bool = False, **_kwargs: Any
90+
self, *_args: Any, stream: bool = False, **kwargs: Any
9091
) -> chat.ChatCompletion | MockAsyncStream[chat.ChatCompletionChunk]:
92+
self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
93+
9194
if stream:
9295
assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided'
9396
# noinspection PyUnresolvedReferences
@@ -106,12 +109,19 @@ async def chat_completions_create( # pragma: no cover
106109
return response
107110

108111

112+
def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
113+
if isinstance(async_open_ai, MockOpenAI):
114+
return async_open_ai.chat_completion_kwargs
115+
else: # pragma: no cover
116+
raise RuntimeError('Not a MockOpenAI instance')
117+
118+
109119
def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
110120
return chat.ChatCompletion(
111121
id='123',
112122
choices=[Choice(finish_reason='stop', index=0, message=message)],
113123
created=1704067200, # 2024-01-01
114-
model='gpt-4',
124+
model='gpt-4o',
115125
object='chat.completion',
116126
usage=usage,
117127
)
@@ -120,7 +130,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
120130
async def test_request_simple_success(allow_model_requests: None):
121131
c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
122132
mock_client = MockOpenAI.create_mock(c)
123-
m = OpenAIModel('gpt-4', openai_client=mock_client)
133+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
124134
agent = Agent(m)
125135

126136
result = await agent.run('hello')
@@ -138,17 +148,29 @@ async def test_request_simple_success(allow_model_requests: None):
138148
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
139149
ModelResponse(
140150
parts=[TextPart(content='world')],
141-
model_name='gpt-4',
151+
model_name='gpt-4o',
142152
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
143153
),
144154
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
145155
ModelResponse(
146156
parts=[TextPart(content='world')],
147-
model_name='gpt-4',
157+
model_name='gpt-4o',
148158
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
149159
),
150160
]
151161
)
162+
assert get_mock_chat_completion_kwargs(mock_client) == [
163+
{'messages': [{'content': 'hello', 'role': 'user'}], 'model': 'gpt-4o', 'n': 1},
164+
{
165+
'messages': [
166+
{'content': 'hello', 'role': 'user'},
167+
{'content': 'world', 'role': 'assistant'},
168+
{'content': 'hello', 'role': 'user'},
169+
],
170+
'model': 'gpt-4o',
171+
'n': 1,
172+
},
173+
]
152174

153175

154176
async def test_request_simple_usage(allow_model_requests: None):
@@ -157,7 +179,7 @@ async def test_request_simple_usage(allow_model_requests: None):
157179
usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
158180
)
159181
mock_client = MockOpenAI.create_mock(c)
160-
m = OpenAIModel('gpt-4', openai_client=mock_client)
182+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
161183
agent = Agent(m)
162184

163185
result = await agent.run('Hello')
@@ -180,7 +202,7 @@ async def test_request_structured_response(allow_model_requests: None):
180202
)
181203
)
182204
mock_client = MockOpenAI.create_mock(c)
183-
m = OpenAIModel('gpt-4', openai_client=mock_client)
205+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
184206
agent = Agent(m, result_type=list[int])
185207

186208
result = await agent.run('Hello')
@@ -196,7 +218,7 @@ async def test_request_structured_response(allow_model_requests: None):
196218
tool_call_id='123',
197219
)
198220
],
199-
model_name='gpt-4',
221+
model_name='gpt-4o',
200222
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
201223
),
202224
ModelRequest(
@@ -256,7 +278,7 @@ async def test_request_tool_call(allow_model_requests: None):
256278
completion_message(ChatCompletionMessage(content='final response', role='assistant')),
257279
]
258280
mock_client = MockOpenAI.create_mock(responses)
259-
m = OpenAIModel('gpt-4', openai_client=mock_client)
281+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
260282
agent = Agent(m, system_prompt='this is the system prompt')
261283

262284
@agent.tool_plain
@@ -284,7 +306,7 @@ async def get_location(loc_name: str) -> str:
284306
tool_call_id='1',
285307
)
286308
],
287-
model_name='gpt-4',
309+
model_name='gpt-4o',
288310
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
289311
),
290312
ModelRequest(
@@ -305,7 +327,7 @@ async def get_location(loc_name: str) -> str:
305327
tool_call_id='2',
306328
)
307329
],
308-
model_name='gpt-4',
330+
model_name='gpt-4o',
309331
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
310332
),
311333
ModelRequest(
@@ -320,7 +342,7 @@ async def get_location(loc_name: str) -> str:
320342
),
321343
ModelResponse(
322344
parts=[TextPart(content='final response')],
323-
model_name='gpt-4',
345+
model_name='gpt-4o',
324346
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
325347
),
326348
]
@@ -346,7 +368,7 @@ def chunk(delta: list[ChoiceDelta], finish_reason: FinishReason | None = None) -
346368
ChunkChoice(index=index, delta=delta, finish_reason=finish_reason) for index, delta in enumerate(delta)
347369
],
348370
created=1704067200, # 2024-01-01
349-
model='gpt-4',
371+
model='gpt-4o',
350372
object='chat.completion.chunk',
351373
usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
352374
)
@@ -359,7 +381,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.Cha
359381
async def test_stream_text(allow_model_requests: None):
360382
stream = text_chunk('hello '), text_chunk('world'), chunk([])
361383
mock_client = MockOpenAI.create_mock_stream(stream)
362-
m = OpenAIModel('gpt-4', openai_client=mock_client)
384+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
363385
agent = Agent(m)
364386

365387
async with agent.run_stream('') as result:
@@ -372,7 +394,7 @@ async def test_stream_text(allow_model_requests: None):
372394
async def test_stream_text_finish_reason(allow_model_requests: None):
373395
stream = text_chunk('hello '), text_chunk('world'), text_chunk('.', finish_reason='stop')
374396
mock_client = MockOpenAI.create_mock_stream(stream)
375-
m = OpenAIModel('gpt-4', openai_client=mock_client)
397+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
376398
agent = Agent(m)
377399

378400
async with agent.run_stream('') as result:
@@ -419,7 +441,7 @@ async def test_stream_structured(allow_model_requests: None):
419441
chunk([]),
420442
)
421443
mock_client = MockOpenAI.create_mock_stream(stream)
422-
m = OpenAIModel('gpt-4', openai_client=mock_client)
444+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
423445
agent = Agent(m, result_type=MyTypedDict)
424446

425447
async with agent.run_stream('') as result:
@@ -447,7 +469,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
447469
struc_chunk(None, None, finish_reason='stop'),
448470
)
449471
mock_client = MockOpenAI.create_mock_stream(stream)
450-
m = OpenAIModel('gpt-4', openai_client=mock_client)
472+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
451473
agent = Agent(m, result_type=MyTypedDict)
452474

453475
async with agent.run_stream('') as result:
@@ -467,7 +489,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
467489
async def test_no_content(allow_model_requests: None):
468490
stream = chunk([ChoiceDelta()]), chunk([ChoiceDelta()])
469491
mock_client = MockOpenAI.create_mock_stream(stream)
470-
m = OpenAIModel('gpt-4', openai_client=mock_client)
492+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
471493
agent = Agent(m, result_type=MyTypedDict)
472494

473495
with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
@@ -482,11 +504,38 @@ async def test_no_delta(allow_model_requests: None):
482504
text_chunk('world'),
483505
)
484506
mock_client = MockOpenAI.create_mock_stream(stream)
485-
m = OpenAIModel('gpt-4', openai_client=mock_client)
507+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
486508
agent = Agent(m)
487509

488510
async with agent.run_stream('') as result:
489511
assert not result.is_complete
490512
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
491513
assert result.is_complete
492514
assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9))
515+
516+
517+
@pytest.mark.parametrize('system_prompt_role', ['system', 'developer', None])
518+
async def test_system_prompt_role(
519+
allow_model_requests: None, system_prompt_role: OpenAISystemPromptRole | None
520+
) -> None:
521+
"""Testing the system prompt role for OpenAI models is properly set / inferred."""
522+
523+
c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
524+
mock_client = MockOpenAI.create_mock(c)
525+
m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, openai_client=mock_client)
526+
assert m.system_prompt_role == system_prompt_role
527+
528+
agent = Agent(m, system_prompt='some instructions')
529+
result = await agent.run('hello')
530+
assert result.data == 'world'
531+
532+
assert get_mock_chat_completion_kwargs(mock_client) == [
533+
{
534+
'messages': [
535+
{'content': 'some instructions', 'role': system_prompt_role or 'system'},
536+
{'content': 'hello', 'role': 'user'},
537+
],
538+
'model': 'gpt-4o',
539+
'n': 1,
540+
}
541+
]

0 commit comments

Comments (0)