Skip to content

Commit 68b74f2

Browse files
authored
Move system_prompt_role from OpenAIModel to OpenAIModelProfile (#2573)
1 parent 9bbbf0c commit 68b74f2

File tree

4 files changed

+178
-16
lines changed

4 files changed

+178
-16
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Any, Literal, Union, cast, overload
1010

1111
from pydantic import ValidationError
12-
from typing_extensions import assert_never
12+
from typing_extensions import assert_never, deprecated
1313

1414
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1515
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
@@ -40,7 +40,7 @@
4040
VideoUrl,
4141
)
4242
from ..profiles import ModelProfile, ModelProfileSpec
43-
from ..profiles.openai import OpenAIModelProfile
43+
from ..profiles.openai import OpenAIModelProfile, OpenAISystemPromptRole
4444
from ..providers import Provider, infer_provider
4545
from ..settings import ModelSettings
4646
from ..tools import ToolDefinition
@@ -100,8 +100,6 @@
100100
allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
101101
"""
102102

103-
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
104-
105103

106104
class OpenAIModelSettings(ModelSettings, total=False):
107105
"""Settings used for an OpenAI model request."""
@@ -196,11 +194,60 @@ class OpenAIModel(Model):
196194
"""
197195

198196
client: AsyncOpenAI = field(repr=False)
199-
system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False)
200197

201198
_model_name: OpenAIModelName = field(repr=False)
202199
_system: str = field(default='openai', repr=False)
203200

201+
@overload
202+
def __init__(
203+
self,
204+
model_name: OpenAIModelName,
205+
*,
206+
provider: Literal[
207+
'openai',
208+
'deepseek',
209+
'azure',
210+
'openrouter',
211+
'moonshotai',
212+
'vercel',
213+
'grok',
214+
'fireworks',
215+
'together',
216+
'heroku',
217+
'github',
218+
'ollama',
219+
]
220+
| Provider[AsyncOpenAI] = 'openai',
221+
profile: ModelProfileSpec | None = None,
222+
settings: ModelSettings | None = None,
223+
) -> None: ...
224+
225+
@deprecated('Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
226+
@overload
227+
def __init__(
228+
self,
229+
model_name: OpenAIModelName,
230+
*,
231+
provider: Literal[
232+
'openai',
233+
'deepseek',
234+
'azure',
235+
'openrouter',
236+
'moonshotai',
237+
'vercel',
238+
'grok',
239+
'fireworks',
240+
'together',
241+
'heroku',
242+
'github',
243+
'ollama',
244+
]
245+
| Provider[AsyncOpenAI] = 'openai',
246+
profile: ModelProfileSpec | None = None,
247+
system_prompt_role: OpenAISystemPromptRole | None = None,
248+
settings: ModelSettings | None = None,
249+
) -> None: ...
250+
204251
def __init__(
205252
self,
206253
model_name: OpenAIModelName,
@@ -242,14 +289,20 @@ def __init__(
242289
provider = infer_provider(provider)
243290
self.client = provider.client
244291

245-
self.system_prompt_role = system_prompt_role
246-
247292
super().__init__(settings=settings, profile=profile or provider.model_profile)
248293

294+
if system_prompt_role is not None:
295+
self.profile = OpenAIModelProfile(openai_system_prompt_role=system_prompt_role).update(self.profile)
296+
249297
@property
250298
def base_url(self) -> str:
251299
return str(self.client.base_url)
252300

301+
@property
302+
@deprecated('Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
303+
def system_prompt_role(self) -> OpenAISystemPromptRole | None:
304+
return OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
305+
253306
async def request(
254307
self,
255308
messages: list[ModelMessage],
@@ -561,9 +614,10 @@ def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolPara
561614
async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
562615
for part in message.parts:
563616
if isinstance(part, SystemPromptPart):
564-
if self.system_prompt_role == 'developer':
617+
system_prompt_role = OpenAIModelProfile.from_profile(self.profile).openai_system_prompt_role
618+
if system_prompt_role == 'developer':
565619
yield chat.ChatCompletionDeveloperMessageParam(role='developer', content=part.content)
566-
elif self.system_prompt_role == 'user':
620+
elif system_prompt_role == 'user':
567621
yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
568622
else:
569623
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
@@ -659,7 +713,6 @@ class OpenAIResponsesModel(Model):
659713
"""
660714

661715
client: AsyncOpenAI = field(repr=False)
662-
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
663716

664717
_model_name: OpenAIModelName = field(repr=False)
665718
_system: str = field(default='openai', repr=False)

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import re
44
from dataclasses import dataclass
5-
from typing import Any
5+
from typing import Any, Literal
66

77
from . import ModelProfile
88
from ._json_schema import JsonSchema, JsonSchemaTransformer
99

10+
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
11+
1012

1113
@dataclass
1214
class OpenAIModelProfile(ModelProfile):
@@ -26,8 +28,10 @@ class OpenAIModelProfile(ModelProfile):
2628
# safe to pass that value along. Default is `True` to preserve existing
2729
# behaviour for OpenAI itself and most providers.
2830
openai_supports_tool_choice_required: bool = True
29-
"""Whether the provider accepts the value ``tool_choice='required'`` in the
30-
request payload."""
31+
"""Whether the provider accepts the value ``tool_choice='required'`` in the request payload."""
32+
33+
openai_system_prompt_role: OpenAISystemPromptRole | None = None
34+
"""The role to use for the system prompt message. If not provided, defaults to `'system'`."""
3135

3236

3337
def openai_model_profile(model_name: str) -> ModelProfile:
@@ -36,11 +40,17 @@ def openai_model_profile(model_name: str) -> ModelProfile:
3640
# Structured Outputs (output mode 'native') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later.
3741
# We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `native` is only used
3842
# when the user specifically uses the `NativeOutput` marker, so an error from the API is acceptable.
43+
44+
# The o1-mini model doesn't support the `system` role, so we default to `user`.
45+
# See https://github.com/pydantic/pydantic-ai/issues/974 for more details.
46+
openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None
47+
3948
return OpenAIModelProfile(
4049
json_schema_transformer=OpenAIJsonSchemaTransformer,
4150
supports_json_schema_output=True,
4251
supports_json_object_output=True,
4352
openai_supports_sampling_settings=not is_reasoning_model,
53+
openai_system_prompt_role=openai_system_prompt_role,
4454
)
4555

4656

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '162'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.openai.com
16+
method: POST
17+
parsed_body:
18+
messages:
19+
- content: You are a helpful assistant.
20+
role: user
21+
- content: What's the capital of France?
22+
role: user
23+
model: o1-mini
24+
stream: false
25+
uri: https://api.openai.com/v1/chat/completions
26+
response:
27+
headers:
28+
access-control-expose-headers:
29+
- X-Request-ID
30+
alt-svc:
31+
- h3=":443"; ma=86400
32+
connection:
33+
- keep-alive
34+
content-length:
35+
- '818'
36+
content-type:
37+
- application/json
38+
openai-organization:
39+
- pydantic-28gund
40+
openai-processing-ms:
41+
- '2319'
42+
openai-project:
43+
- proj_dKobscVY9YJxeEaDJen54e3d
44+
openai-version:
45+
- '2020-10-01'
46+
strict-transport-security:
47+
- max-age=31536000; includeSubDomains; preload
48+
transfer-encoding:
49+
- chunked
50+
parsed_body:
51+
choices:
52+
- finish_reason: stop
53+
index: 0
54+
message:
55+
annotations: []
56+
content: The capital of France is **Paris**.
57+
refusal: null
58+
role: assistant
59+
created: 1755256071
60+
id: chatcmpl-C4mZjhnq5PQ6hDaKfMKb1WtXeSYzu
61+
model: o1-mini-2024-09-12
62+
object: chat.completion
63+
service_tier: default
64+
system_fingerprint: fp_79455e3cfb
65+
usage:
66+
completion_tokens: 212
67+
completion_tokens_details:
68+
accepted_prediction_tokens: 0
69+
audio_tokens: 0
70+
reasoning_tokens: 192
71+
rejected_prediction_tokens: 0
72+
prompt_tokens: 30
73+
prompt_tokens_details:
74+
audio_tokens: 0
75+
cached_tokens: 0
76+
total_tokens: 242
77+
status:
78+
code: 200
79+
message: OK
80+
version: 1

tests/models/test_openai.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ async def test_no_delta(allow_model_requests: None):
673673
assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3))
674674

675675

676+
@pytest.mark.filterwarnings('ignore:Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
676677
@pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None])
677678
async def test_system_prompt_role(
678679
allow_model_requests: None, system_prompt_role: OpenAISystemPromptRole | None
@@ -681,8 +682,8 @@ async def test_system_prompt_role(
681682

682683
c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
683684
mock_client = MockOpenAI.create_mock(c)
684-
m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, provider=OpenAIProvider(openai_client=mock_client))
685-
assert m.system_prompt_role == system_prompt_role
685+
m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, provider=OpenAIProvider(openai_client=mock_client)) # type: ignore[reportDeprecated]
686+
assert m.system_prompt_role == system_prompt_role # type: ignore[reportDeprecated]
686687

687688
agent = Agent(m, system_prompt='some instructions')
688689
result = await agent.run('hello')
@@ -701,13 +702,31 @@ async def test_system_prompt_role(
701702
]
702703

703704

705+
async def test_system_prompt_role_o1_mini(allow_model_requests: None, openai_api_key: str):
706+
model = OpenAIModel('o1-mini', provider=OpenAIProvider(api_key=openai_api_key))
707+
agent = Agent(model=model, system_prompt='You are a helpful assistant.')
708+
709+
result = await agent.run("What's the capital of France?")
710+
assert result.output == snapshot('The capital of France is **Paris**.')
711+
712+
713+
async def test_openai_pass_custom_system_prompt_role(allow_model_requests: None, openai_api_key: str):
714+
profile = ModelProfile(supports_tools=False)
715+
model = OpenAIModel( # type: ignore[reportDeprecated]
716+
'o1-mini', profile=profile, provider=OpenAIProvider(api_key=openai_api_key), system_prompt_role='user'
717+
)
718+
profile = OpenAIModelProfile.from_profile(model.profile)
719+
assert profile.openai_system_prompt_role == 'user'
720+
assert profile.supports_tools is False
721+
722+
704723
@pytest.mark.parametrize('system_prompt_role', ['system', 'developer'])
705724
async def test_openai_o1_mini_system_role(
706725
allow_model_requests: None,
707726
system_prompt_role: Literal['system', 'developer'],
708727
openai_api_key: str,
709728
) -> None:
710-
model = OpenAIModel(
729+
model = OpenAIModel( # type: ignore[reportDeprecated]
711730
'o1-mini', provider=OpenAIProvider(api_key=openai_api_key), system_prompt_role=system_prompt_role
712731
)
713732
agent = Agent(model=model, system_prompt='You are a helpful assistant.')

0 commit comments

Comments
 (0)