@@ -1,7 +1,7 @@
 from __future__ import annotations as _annotations
 
 import json
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import timezone
 from functools import cached_property
 from typing import Any, cast
@@ -22,11 +22,12 @@
     UserPromptPart,
 )
 from pydantic_ai.result import Usage
+from pydantic_ai.settings import ModelSettings
 
 from ..conftest import IsNow, try_import
 
 with try_import() as imports_successful:
-    from anthropic import AsyncAnthropic
+    from anthropic import NOT_GIVEN, AsyncAnthropic
     from anthropic.types import (
         ContentBlock,
         Message as AnthropicMessage,
@@ -53,6 +54,7 @@ def test_init():
 class MockAnthropic:
     messages_: AnthropicMessage | list[AnthropicMessage] | None = None
     index = 0
+    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
 
     @cached_property
     def messages(self) -> Any:
@@ -62,7 +64,9 @@ def messages(self) -> Any:
     def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
         return cast(AsyncAnthropic, cls(messages_=messages_))
 
-    async def messages_create(self, *_args: Any, **_kwargs: Any) -> AnthropicMessage:
+    async def messages_create(self, *_args: Any, **kwargs: Any) -> AnthropicMessage:
+        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
+
         assert self.messages_ is not None, '`messages` must be provided'
         if isinstance(self.messages_, list):
             response = self.messages_[self.index]
@@ -257,3 +261,40 @@ async def get_location(loc_name: str) -> str:
             ),
         ]
     )
+
+
+def get_mock_chat_completion_kwargs(async_anthropic: AsyncAnthropic) -> list[dict[str, Any]]:
+    if isinstance(async_anthropic, MockAnthropic):
+        return async_anthropic.chat_completion_kwargs
+    else:  # pragma: no cover
+        raise RuntimeError('Not a MockAnthropic instance')
+
+
+@pytest.mark.parametrize('parallel_tool_calls', [True, False])
+async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
+    responses = [
+        completion_message(
+            [ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
+            usage=AnthropicUsage(input_tokens=2, output_tokens=1),
+        ),
+        completion_message(
+            [TextBlock(text='final response', type='text')],
+            usage=AnthropicUsage(input_tokens=3, output_tokens=5),
+        ),
+    ]
+
+    mock_client = MockAnthropic.create_mock(responses)
+    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
+    agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
+
+    @agent.tool_plain
+    async def get_location(loc_name: str) -> str:
+        if loc_name == 'London':
+            return json.dumps({'lat': 51, 'lng': 0})
+        else:
+            raise ModelRetry('Wrong location, please try again')
+
+    await agent.run('hello')
+    assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
+        not parallel_tool_calls
+    )
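
Aside, not part of the diff: a minimal sketch of the capture-and-inspect pattern the new test relies on, assuming the MockAnthropic client defined above; it only restates what the test asserts and adds explanatory comments.

    # get_mock_chat_completion_kwargs returns one dict per messages.create() call,
    # with Anthropic's NOT_GIVEN sentinel values already filtered out by the mock.
    first_request = get_mock_chat_completion_kwargs(mock_client)[0]
    # The test pins down only the parallel-tool-call flag: Anthropic expresses it as
    # disable_parallel_tool_use inside tool_choice, i.e. the inverse of pydantic-ai's
    # parallel_tool_calls setting.
    assert first_request['tool_choice']['disable_parallel_tool_use'] == (not parallel_tool_calls)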