Skip to content

Commit 5894c38

Browse files
Support parallel_tool_calls in ModelSettings (#750)
1 parent eba8a7d commit 5894c38

File tree

6 files changed

+97
-12
lines changed

6 files changed

+97
-12
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,16 +186,22 @@ async def _messages_create(
186186
self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
187187
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
188188
# standalone function to make it easier to override
189+
model_settings = model_settings or {}
190+
191+
tool_choice: ToolChoiceParam | None
192+
189193
if not self.tools:
190-
tool_choice: ToolChoiceParam | None = None
191-
elif not self.allow_text_result:
192-
tool_choice = {'type': 'any'}
194+
tool_choice = None
193195
else:
194-
tool_choice = {'type': 'auto'}
196+
if not self.allow_text_result:
197+
tool_choice = {'type': 'any'}
198+
else:
199+
tool_choice = {'type': 'auto'}
195200

196-
system_prompt, anthropic_messages = self._map_message(messages)
201+
if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
202+
tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
197203

198-
model_settings = model_settings or {}
204+
system_prompt, anthropic_messages = self._map_message(messages)
199205

200206
return await self.client.messages.create(
201207
max_tokens=model_settings.get('max_tokens', 1024),

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ async def _completions_create(
197197
model=str(self.model_name),
198198
messages=groq_messages,
199199
n=1,
200-
parallel_tool_calls=True if self.tools else NOT_GIVEN,
200+
parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
201201
tools=self.tools or NOT_GIVEN,
202202
tool_choice=tool_choice or NOT_GIVEN,
203203
stream=stream,

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ async def _completions_create(
195195
model=self.model_name,
196196
messages=openai_messages,
197197
n=1,
198-
parallel_tool_calls=True if self.tools else NOT_GIVEN,
198+
parallel_tool_calls=model_settings.get('parallel_tool_calls', True if self.tools else NOT_GIVEN),
199199
tools=self.tools or NOT_GIVEN,
200200
tool_choice=tool_choice or NOT_GIVEN,
201201
stream=stream,

pydantic_ai_slim/pydantic_ai/settings.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
class ModelSettings(TypedDict, total=False):
1313
"""Settings to configure an LLM.
1414
15-
Here we include only settings which apply to multiple models / model providers.
15+
Here we include only settings which apply to multiple models / model providers,
16+
though not all of these settings are supported by all models.
1617
"""
1718

1819
max_tokens: int
@@ -25,6 +26,7 @@ class ModelSettings(TypedDict, total=False):
2526
* OpenAI
2627
* Groq
2728
* Cohere
29+
* Mistral
2830
"""
2931

3032
temperature: float
@@ -42,6 +44,7 @@ class ModelSettings(TypedDict, total=False):
4244
* OpenAI
4345
* Groq
4446
* Cohere
47+
* Mistral
4548
"""
4649

4750
top_p: float
@@ -58,6 +61,7 @@ class ModelSettings(TypedDict, total=False):
5861
* OpenAI
5962
* Groq
6063
* Cohere
64+
* Mistral
6165
"""
6266

6367
timeout: float | Timeout
@@ -69,6 +73,16 @@ class ModelSettings(TypedDict, total=False):
6973
* Anthropic
7074
* OpenAI
7175
* Groq
76+
* Mistral
77+
"""
78+
79+
parallel_tool_calls: bool
80+
"""Whether to allow parallel tool calls.
81+
82+
Supported by:
83+
* OpenAI
84+
* Groq
85+
* Anthropic
7286
"""
7387

7488

tests/models/test_anthropic.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import json
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from datetime import timezone
66
from functools import cached_property
77
from typing import Any, cast
@@ -22,11 +22,12 @@
2222
UserPromptPart,
2323
)
2424
from pydantic_ai.result import Usage
25+
from pydantic_ai.settings import ModelSettings
2526

2627
from ..conftest import IsNow, try_import
2728

2829
with try_import() as imports_successful:
29-
from anthropic import AsyncAnthropic
30+
from anthropic import NOT_GIVEN, AsyncAnthropic
3031
from anthropic.types import (
3132
ContentBlock,
3233
Message as AnthropicMessage,
@@ -53,6 +54,7 @@ def test_init():
5354
class MockAnthropic:
5455
messages_: AnthropicMessage | list[AnthropicMessage] | None = None
5556
index = 0
57+
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
5658

5759
@cached_property
5860
def messages(self) -> Any:
@@ -62,7 +64,9 @@ def messages(self) -> Any:
6264
def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
6365
return cast(AsyncAnthropic, cls(messages_=messages_))
6466

65-
async def messages_create(self, *_args: Any, **_kwargs: Any) -> AnthropicMessage:
67+
async def messages_create(self, *_args: Any, **kwargs: Any) -> AnthropicMessage:
68+
self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})
69+
6670
assert self.messages_ is not None, '`messages` must be provided'
6771
if isinstance(self.messages_, list):
6872
response = self.messages_[self.index]
@@ -257,3 +261,40 @@ async def get_location(loc_name: str) -> str:
257261
),
258262
]
259263
)
264+
265+
266+
def get_mock_chat_completion_kwargs(async_anthropic: AsyncAnthropic) -> list[dict[str, Any]]:
267+
if isinstance(async_anthropic, MockAnthropic):
268+
return async_anthropic.chat_completion_kwargs
269+
else: # pragma: no cover
270+
raise RuntimeError('Not a MockOpenAI instance')
271+
272+
273+
@pytest.mark.parametrize('parallel_tool_calls', [True, False])
274+
async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
275+
responses = [
276+
completion_message(
277+
[ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
278+
usage=AnthropicUsage(input_tokens=2, output_tokens=1),
279+
),
280+
completion_message(
281+
[TextBlock(text='final response', type='text')],
282+
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
283+
),
284+
]
285+
286+
mock_client = MockAnthropic.create_mock(responses)
287+
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
288+
agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
289+
290+
@agent.tool_plain
291+
async def get_location(loc_name: str) -> str:
292+
if loc_name == 'London':
293+
return json.dumps({'lat': 51, 'lng': 0})
294+
else:
295+
raise ModelRetry('Wrong location, please try again')
296+
297+
await agent.run('hello')
298+
assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
299+
not parallel_tool_calls
300+
)

tests/models/test_openai.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
UserPromptPart,
2424
)
2525
from pydantic_ai.result import Usage
26+
from pydantic_ai.settings import ModelSettings
2627

2728
from ..conftest import IsNow, try_import
2829
from .mock_async_stream import MockAsyncStream
@@ -539,3 +540,26 @@ async def test_system_prompt_role(
539540
'n': 1,
540541
}
541542
]
543+
544+
545+
@pytest.mark.parametrize('parallel_tool_calls', [True, False])
546+
async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
547+
c = completion_message(
548+
ChatCompletionMessage(
549+
content=None,
550+
role='assistant',
551+
tool_calls=[
552+
chat.ChatCompletionMessageToolCall(
553+
id='123',
554+
function=Function(arguments='{"response": [1, 2, 3]}', name='final_result'),
555+
type='function',
556+
)
557+
],
558+
)
559+
)
560+
mock_client = MockOpenAI.create_mock(c)
561+
m = OpenAIModel('gpt-4o', openai_client=mock_client)
562+
agent = Agent(m, result_type=list[int], model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
563+
564+
await agent.run('Hello')
565+
assert get_mock_chat_completion_kwargs(mock_client)[0]['parallel_tool_calls'] == parallel_tool_calls

0 commit comments

Comments (0)