2
2
3
3
import json
4
4
from collections .abc import Sequence
5
- from dataclasses import dataclass
5
+ from dataclasses import dataclass , field
6
6
from datetime import datetime , timezone
7
7
from functools import cached_property
8
8
from typing import Any , Literal , cast
28
28
from .mock_async_stream import MockAsyncStream
29
29
30
30
with try_import () as imports_successful :
31
- from openai import AsyncOpenAI
31
+ from openai import NOT_GIVEN , AsyncOpenAI
32
32
from openai .types import chat
33
33
from openai .types .chat .chat_completion import Choice
34
34
from openai .types .chat .chat_completion_chunk import (
41
41
from openai .types .chat .chat_completion_message_tool_call import Function
42
42
from openai .types .completion_usage import CompletionUsage , PromptTokensDetails
43
43
44
- from pydantic_ai .models .openai import OpenAIModel
44
+ from pydantic_ai .models .openai import OpenAIModel , OpenAISystemPromptRole
45
45
46
46
pytestmark = [
47
47
pytest .mark .skipif (not imports_successful (), reason = 'openai not installed' ),
50
50
51
51
52
52
def test_init():
    """OpenAIModel picks up the default base URL and stores the API key and model name."""
    m = OpenAIModel('gpt-4o', api_key='foobar')
    # Default client should point at the public OpenAI endpoint.
    assert str(m.client.base_url) == 'https://api.openai.com/v1/'
    assert m.client.api_key == 'foobar'
    assert m.name() == 'openai:gpt-4o'
57
57
58
58
59
59
def test_init_with_base_url():
    """A custom ``base_url`` overrides the default OpenAI endpoint."""
    m = OpenAIModel('gpt-4o', base_url='https://example.com/v1', api_key='foobar')
    # The openai client normalizes the base URL with a trailing slash.
    assert str(m.client.base_url) == 'https://example.com/v1/'
    assert m.client.api_key == 'foobar'
    assert m.name() == 'openai:gpt-4o'
    m.name()
65
65
66
66
67
67
@dataclass
68
68
class MockOpenAI :
69
69
completions : chat .ChatCompletion | list [chat .ChatCompletion ] | None = None
70
70
stream : list [chat .ChatCompletionChunk ] | list [list [chat .ChatCompletionChunk ]] | None = None
71
- index = 0
71
+ index : int = 0
72
+ chat_completion_kwargs : list [dict [str , Any ]] = field (default_factory = list )
72
73
73
74
@cached_property
74
75
def chat (self ) -> Any :
@@ -86,8 +87,10 @@ def create_mock_stream(
86
87
return cast (AsyncOpenAI , cls (stream = list (stream ))) # pyright: ignore[reportArgumentType]
87
88
88
89
async def chat_completions_create ( # pragma: no cover
89
- self , * _args : Any , stream : bool = False , ** _kwargs : Any
90
+ self , * _args : Any , stream : bool = False , ** kwargs : Any
90
91
) -> chat .ChatCompletion | MockAsyncStream [chat .ChatCompletionChunk ]:
92
+ self .chat_completion_kwargs .append ({k : v for k , v in kwargs .items () if v is not NOT_GIVEN })
93
+
91
94
if stream :
92
95
assert self .stream is not None , 'you can only used `stream=True` if `stream` is provided'
93
96
# noinspection PyUnresolvedReferences
@@ -106,12 +109,19 @@ async def chat_completions_create( # pragma: no cover
106
109
return response
107
110
108
111
112
def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
    """Return the kwargs recorded by a ``MockOpenAI`` client, one dict per request.

    Raises:
        RuntimeError: if the client is a real ``AsyncOpenAI`` rather than the mock.
    """
    if isinstance(async_open_ai, MockOpenAI):
        return async_open_ai.chat_completion_kwargs
    else:  # pragma: no cover
        raise RuntimeError('Not a MockOpenAI instance')
117
+
118
+
109
119
def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
    """Wrap *message* in a minimal ``chat.ChatCompletion`` with fixed id/model/timestamp."""
    return chat.ChatCompletion(
        id='123',
        choices=[Choice(finish_reason='stop', index=0, message=message)],
        created=1704067200,  # 2024-01-01
        model='gpt-4o',
        object='chat.completion',
        usage=usage,
    )
@@ -120,7 +130,7 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage
120
130
async def test_request_simple_success (allow_model_requests : None ):
121
131
c = completion_message (ChatCompletionMessage (content = 'world' , role = 'assistant' ))
122
132
mock_client = MockOpenAI .create_mock (c )
123
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
133
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
124
134
agent = Agent (m )
125
135
126
136
result = await agent .run ('hello' )
@@ -138,17 +148,29 @@ async def test_request_simple_success(allow_model_requests: None):
138
148
ModelRequest (parts = [UserPromptPart (content = 'hello' , timestamp = IsNow (tz = timezone .utc ))]),
139
149
ModelResponse (
140
150
parts = [TextPart (content = 'world' )],
141
- model_name = 'gpt-4 ' ,
151
+ model_name = 'gpt-4o ' ,
142
152
timestamp = datetime (2024 , 1 , 1 , 0 , 0 , tzinfo = timezone .utc ),
143
153
),
144
154
ModelRequest (parts = [UserPromptPart (content = 'hello' , timestamp = IsNow (tz = timezone .utc ))]),
145
155
ModelResponse (
146
156
parts = [TextPart (content = 'world' )],
147
- model_name = 'gpt-4 ' ,
157
+ model_name = 'gpt-4o ' ,
148
158
timestamp = datetime (2024 , 1 , 1 , 0 , 0 , tzinfo = timezone .utc ),
149
159
),
150
160
]
151
161
)
162
+ assert get_mock_chat_completion_kwargs (mock_client ) == [
163
+ {'messages' : [{'content' : 'hello' , 'role' : 'user' }], 'model' : 'gpt-4o' , 'n' : 1 },
164
+ {
165
+ 'messages' : [
166
+ {'content' : 'hello' , 'role' : 'user' },
167
+ {'content' : 'world' , 'role' : 'assistant' },
168
+ {'content' : 'hello' , 'role' : 'user' },
169
+ ],
170
+ 'model' : 'gpt-4o' ,
171
+ 'n' : 1 ,
172
+ },
173
+ ]
152
174
153
175
154
176
async def test_request_simple_usage (allow_model_requests : None ):
@@ -157,7 +179,7 @@ async def test_request_simple_usage(allow_model_requests: None):
157
179
usage = CompletionUsage (completion_tokens = 1 , prompt_tokens = 2 , total_tokens = 3 ),
158
180
)
159
181
mock_client = MockOpenAI .create_mock (c )
160
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
182
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
161
183
agent = Agent (m )
162
184
163
185
result = await agent .run ('Hello' )
@@ -180,7 +202,7 @@ async def test_request_structured_response(allow_model_requests: None):
180
202
)
181
203
)
182
204
mock_client = MockOpenAI .create_mock (c )
183
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
205
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
184
206
agent = Agent (m , result_type = list [int ])
185
207
186
208
result = await agent .run ('Hello' )
@@ -196,7 +218,7 @@ async def test_request_structured_response(allow_model_requests: None):
196
218
tool_call_id = '123' ,
197
219
)
198
220
],
199
- model_name = 'gpt-4 ' ,
221
+ model_name = 'gpt-4o ' ,
200
222
timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
201
223
),
202
224
ModelRequest (
@@ -256,7 +278,7 @@ async def test_request_tool_call(allow_model_requests: None):
256
278
completion_message (ChatCompletionMessage (content = 'final response' , role = 'assistant' )),
257
279
]
258
280
mock_client = MockOpenAI .create_mock (responses )
259
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
281
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
260
282
agent = Agent (m , system_prompt = 'this is the system prompt' )
261
283
262
284
@agent .tool_plain
@@ -284,7 +306,7 @@ async def get_location(loc_name: str) -> str:
284
306
tool_call_id = '1' ,
285
307
)
286
308
],
287
- model_name = 'gpt-4 ' ,
309
+ model_name = 'gpt-4o ' ,
288
310
timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
289
311
),
290
312
ModelRequest (
@@ -305,7 +327,7 @@ async def get_location(loc_name: str) -> str:
305
327
tool_call_id = '2' ,
306
328
)
307
329
],
308
- model_name = 'gpt-4 ' ,
330
+ model_name = 'gpt-4o ' ,
309
331
timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
310
332
),
311
333
ModelRequest (
@@ -320,7 +342,7 @@ async def get_location(loc_name: str) -> str:
320
342
),
321
343
ModelResponse (
322
344
parts = [TextPart (content = 'final response' )],
323
- model_name = 'gpt-4 ' ,
345
+ model_name = 'gpt-4o ' ,
324
346
timestamp = datetime (2024 , 1 , 1 , tzinfo = timezone .utc ),
325
347
),
326
348
]
@@ -346,7 +368,7 @@ def chunk(delta: list[ChoiceDelta], finish_reason: FinishReason | None = None) -
346
368
ChunkChoice (index = index , delta = delta , finish_reason = finish_reason ) for index , delta in enumerate (delta )
347
369
],
348
370
created = 1704067200 , # 2024-01-01
349
- model = 'gpt-4 ' ,
371
+ model = 'gpt-4o ' ,
350
372
object = 'chat.completion.chunk' ,
351
373
usage = CompletionUsage (completion_tokens = 1 , prompt_tokens = 2 , total_tokens = 3 ),
352
374
)
@@ -359,7 +381,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> chat.Cha
359
381
async def test_stream_text (allow_model_requests : None ):
360
382
stream = text_chunk ('hello ' ), text_chunk ('world' ), chunk ([])
361
383
mock_client = MockOpenAI .create_mock_stream (stream )
362
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
384
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
363
385
agent = Agent (m )
364
386
365
387
async with agent .run_stream ('' ) as result :
@@ -372,7 +394,7 @@ async def test_stream_text(allow_model_requests: None):
372
394
async def test_stream_text_finish_reason (allow_model_requests : None ):
373
395
stream = text_chunk ('hello ' ), text_chunk ('world' ), text_chunk ('.' , finish_reason = 'stop' )
374
396
mock_client = MockOpenAI .create_mock_stream (stream )
375
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
397
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
376
398
agent = Agent (m )
377
399
378
400
async with agent .run_stream ('' ) as result :
@@ -419,7 +441,7 @@ async def test_stream_structured(allow_model_requests: None):
419
441
chunk ([]),
420
442
)
421
443
mock_client = MockOpenAI .create_mock_stream (stream )
422
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
444
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
423
445
agent = Agent (m , result_type = MyTypedDict )
424
446
425
447
async with agent .run_stream ('' ) as result :
@@ -447,7 +469,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
447
469
struc_chunk (None , None , finish_reason = 'stop' ),
448
470
)
449
471
mock_client = MockOpenAI .create_mock_stream (stream )
450
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
472
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
451
473
agent = Agent (m , result_type = MyTypedDict )
452
474
453
475
async with agent .run_stream ('' ) as result :
@@ -467,7 +489,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
467
489
async def test_no_content (allow_model_requests : None ):
468
490
stream = chunk ([ChoiceDelta ()]), chunk ([ChoiceDelta ()])
469
491
mock_client = MockOpenAI .create_mock_stream (stream )
470
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
492
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
471
493
agent = Agent (m , result_type = MyTypedDict )
472
494
473
495
with pytest .raises (UnexpectedModelBehavior , match = 'Received empty model response' ):
@@ -482,11 +504,38 @@ async def test_no_delta(allow_model_requests: None):
482
504
text_chunk ('world' ),
483
505
)
484
506
mock_client = MockOpenAI .create_mock_stream (stream )
485
- m = OpenAIModel ('gpt-4 ' , openai_client = mock_client )
507
+ m = OpenAIModel ('gpt-4o ' , openai_client = mock_client )
486
508
agent = Agent (m )
487
509
488
510
async with agent .run_stream ('' ) as result :
489
511
assert not result .is_complete
490
512
assert [c async for c in result .stream_text (debounce_by = None )] == snapshot (['hello ' , 'hello world' ])
491
513
assert result .is_complete
492
514
assert result .usage () == snapshot (Usage (requests = 1 , request_tokens = 6 , response_tokens = 3 , total_tokens = 9 ))
515
+
516
+
517
@pytest.mark.parametrize('system_prompt_role', ['system', 'developer', None])
async def test_system_prompt_role(
    allow_model_requests: None, system_prompt_role: OpenAISystemPromptRole | None
) -> None:
    """Testing the system prompt role for OpenAI models is properly set / inferred."""

    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
    mock_client = MockOpenAI.create_mock(c)
    m = OpenAIModel('gpt-4o', system_prompt_role=system_prompt_role, openai_client=mock_client)
    assert m.system_prompt_role == system_prompt_role

    agent = Agent(m, system_prompt='some instructions')
    result = await agent.run('hello')
    assert result.data == 'world'

    # The recorded request must use the configured role, defaulting to 'system' when None.
    assert get_mock_chat_completion_kwargs(mock_client) == [
        {
            'messages': [
                {'content': 'some instructions', 'role': system_prompt_role or 'system'},
                {'content': 'hello', 'role': 'user'},
            ],
            'model': 'gpt-4o',
            'n': 1,
        }
    ]
0 commit comments