Skip to content

Commit faa3868

Browse files
authored
Fix support for multiple MCPServerTool builtin_tools (#3239)
1 parent 3862a33 commit faa3868

File tree

9 files changed

+144
-176
lines changed

9 files changed

+144
-176
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ async def main():
542542
"""
543543
if infer_name and self.name is None:
544544
self._infer_name(inspect.currentframe())
545+
545546
model_used = self._get_model(model)
546547
del model
547548

@@ -607,16 +608,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
607608
else:
608609
instrumentation_settings = None
609610
tracer = NoOpTracer()
610-
if builtin_tools:
611-
# Deduplicate builtin tools passed to the agent and the run based on type
612-
builtin_tools = list(
613-
{
614-
**({type(tool): tool for tool in self._builtin_tools or []}),
615-
**({type(tool): tool for tool in builtin_tools}),
616-
}.values()
617-
)
618-
else:
619-
builtin_tools = list(self._builtin_tools)
611+
620612
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
621613
user_deps=deps,
622614
prompt=user_prompt,
@@ -629,7 +621,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
629621
output_schema=output_schema,
630622
output_validators=output_validators,
631623
history_processors=self.history_processors,
632-
builtin_tools=builtin_tools,
624+
builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
633625
tool_manager=tool_manager,
634626
tracer=tracer,
635627
get_instructions=get_instructions,

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ class AbstractBuiltinTool(ABC):
3434
kind: str = 'unknown_builtin_tool'
3535
"""Built-in tool identifier, this should be available on all built-in tools as a discriminator."""
3636

37+
@property
38+
def unique_id(self) -> str:
39+
"""A unique identifier for the builtin tool.
40+
41+
If multiple instances of the same builtin tool can be passed to the model, subclasses should override this property to allow them to be distinguished.
42+
"""
43+
return self.kind
44+
3745
def __init_subclass__(cls, **kwargs: Any) -> None:
3846
super().__init_subclass__(**kwargs)
3947
_BUILTIN_TOOL_TYPES[cls.kind] = cls
@@ -275,7 +283,7 @@ class MCPServerTool(AbstractBuiltinTool):
275283
"""
276284

277285
id: str
278-
"""The ID of the MCP server."""
286+
"""A unique identifier for the MCP server."""
279287

280288
url: str
281289
"""The URL of the MCP server to use.
@@ -321,6 +329,10 @@ class MCPServerTool(AbstractBuiltinTool):
321329

322330
kind: str = 'mcp_server'
323331

332+
@property
333+
def unique_id(self) -> str:
334+
return ':'.join([self.kind, self.id])
335+
324336

325337
def _tool_discriminator(tool_data: dict[str, Any] | AbstractBuiltinTool) -> str:
326338
if isinstance(tool_data, dict):

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,17 @@ def prepare_request(
410410
they need to customize the preparation flow further, but most implementations should simply call
411411
``self.prepare_request(...)`` at the start of their ``request`` (and related) methods.
412412
"""
413-
merged_settings = merge_model_settings(self.settings, model_settings)
414-
customized_parameters = self.customize_request_parameters(model_request_parameters)
415-
return merged_settings, customized_parameters
413+
model_settings = merge_model_settings(self.settings, model_settings)
414+
415+
if builtin_tools := model_request_parameters.builtin_tools:
416+
# Deduplicate builtin tools
417+
model_request_parameters = replace(
418+
model_request_parameters,
419+
builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()),
420+
)
421+
422+
model_request_parameters = self.customize_request_parameters(model_request_parameters)
423+
return model_settings, model_request_parameters
416424

417425
@property
418426
@abstractmethod

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,14 @@ async def _messages_create(
324324
def _process_response(self, response: BetaMessage) -> ModelResponse:
325325
"""Process a non-streamed response, and prepare a message to return."""
326326
items: list[ModelResponsePart] = []
327+
builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
327328
for item in response.content:
328329
if isinstance(item, BetaTextBlock):
329330
items.append(TextPart(content=item.text))
330331
elif isinstance(item, BetaServerToolUseBlock):
331-
items.append(_map_server_tool_use_block(item, self.system))
332+
call_part = _map_server_tool_use_block(item, self.system)
333+
builtin_tool_calls[call_part.tool_call_id] = call_part
334+
items.append(call_part)
332335
elif isinstance(item, BetaWebSearchToolResultBlock):
333336
items.append(_map_web_search_tool_result_block(item, self.system))
334337
elif isinstance(item, BetaCodeExecutionToolResultBlock):
@@ -340,9 +343,12 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
340343
elif isinstance(item, BetaThinkingBlock):
341344
items.append(ThinkingPart(content=item.thinking, signature=item.signature, provider_name=self.system))
342345
elif isinstance(item, BetaMCPToolUseBlock):
343-
items.append(_map_mcp_server_use_block(item, self.system))
346+
call_part = _map_mcp_server_use_block(item, self.system)
347+
builtin_tool_calls[call_part.tool_call_id] = call_part
348+
items.append(call_part)
344349
elif isinstance(item, BetaMCPToolResultBlock):
345-
items.append(_map_mcp_server_result_block(item, self.system))
350+
call_part = builtin_tool_calls.get(item.tool_use_id)
351+
items.append(_map_mcp_server_result_block(item, call_part, self.system))
346352
else:
347353
assert isinstance(item, BetaToolUseBlock), f'unexpected item type {type(item)}'
348354
items.append(
@@ -545,9 +551,9 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
545551
)
546552
assistant_content_params.append(server_tool_use_block_param)
547553
elif (
548-
response_part.tool_name == MCPServerTool.kind
554+
response_part.tool_name.startswith(MCPServerTool.kind)
555+
and (server_id := response_part.tool_name.split(':', 1)[1])
549556
and (args := response_part.args_as_dict())
550-
and (server_id := args.get('server_id'))
551557
and (tool_name := args.get('tool_name'))
552558
and (tool_args := args.get('tool_args'))
553559
): # pragma: no branch
@@ -590,7 +596,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
590596
),
591597
)
592598
)
593-
elif response_part.tool_name == MCPServerTool.kind and isinstance(
599+
elif response_part.tool_name.startswith(MCPServerTool.kind) and isinstance(
594600
response_part.content, dict
595601
): # pragma: no branch
596602
assistant_content_params.append(
@@ -714,6 +720,7 @@ class AnthropicStreamedResponse(StreamedResponse):
714720
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
715721
current_block: BetaContentBlock | None = None
716722

723+
builtin_tool_calls: dict[str, BuiltinToolCallPart] = {}
717724
async for event in self._response:
718725
if isinstance(event, BetaRawMessageStartEvent):
719726
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
@@ -751,9 +758,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
751758
if maybe_event is not None: # pragma: no branch
752759
yield maybe_event
753760
elif isinstance(current_block, BetaServerToolUseBlock):
761+
call_part = _map_server_tool_use_block(current_block, self.provider_name)
762+
builtin_tool_calls[call_part.tool_call_id] = call_part
754763
yield self._parts_manager.handle_part(
755764
vendor_part_id=event.index,
756-
part=_map_server_tool_use_block(current_block, self.provider_name),
765+
part=call_part,
757766
)
758767
elif isinstance(current_block, BetaWebSearchToolResultBlock):
759768
yield self._parts_manager.handle_part(
@@ -767,6 +776,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
767776
)
768777
elif isinstance(current_block, BetaMCPToolUseBlock):
769778
call_part = _map_mcp_server_use_block(current_block, self.provider_name)
779+
builtin_tool_calls[call_part.tool_call_id] = call_part
770780

771781
args_json = call_part.args_as_json_str()
772782
# Drop the final `{}}` so that we can add tool args deltas
@@ -785,9 +795,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
785795
if maybe_event is not None: # pragma: no branch
786796
yield maybe_event
787797
elif isinstance(current_block, BetaMCPToolResultBlock):
798+
call_part = builtin_tool_calls.get(current_block.tool_use_id)
788799
yield self._parts_manager.handle_part(
789800
vendor_part_id=event.index,
790-
part=_map_mcp_server_result_block(current_block, self.provider_name),
801+
part=_map_mcp_server_result_block(current_block, call_part, self.provider_name),
791802
)
792803

793804
elif isinstance(event, BetaRawContentBlockDeltaEvent):
@@ -908,21 +919,22 @@ def _map_code_execution_tool_result_block(
908919
def _map_mcp_server_use_block(item: BetaMCPToolUseBlock, provider_name: str) -> BuiltinToolCallPart:
909920
return BuiltinToolCallPart(
910921
provider_name=provider_name,
911-
tool_name=MCPServerTool.kind,
922+
tool_name=':'.join([MCPServerTool.kind, item.server_name]),
912923
args={
913924
'action': 'call_tool',
914-
'server_id': item.server_name,
915925
'tool_name': item.name,
916926
'tool_args': cast(dict[str, Any], item.input),
917927
},
918928
tool_call_id=item.id,
919929
)
920930

921931

922-
def _map_mcp_server_result_block(item: BetaMCPToolResultBlock, provider_name: str) -> BuiltinToolReturnPart:
932+
def _map_mcp_server_result_block(
933+
item: BetaMCPToolResultBlock, call_part: BuiltinToolCallPart | None, provider_name: str
934+
) -> BuiltinToolReturnPart:
923935
return BuiltinToolReturnPart(
924936
provider_name=provider_name,
925-
tool_name=MCPServerTool.kind,
937+
tool_name=call_part.tool_name if call_part else MCPServerTool.kind,
926938
content=item.model_dump(mode='json', include={'content', 'is_error'}),
927939
tool_call_id=item.tool_use_id,
928940
)

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,11 +1474,11 @@ async def _map_messages( # noqa: C901
14741474
)
14751475
openai_messages.append(image_generation_item)
14761476
elif ( # pragma: no branch
1477-
item.tool_name == MCPServerTool.kind
1477+
item.tool_name.startswith(MCPServerTool.kind)
14781478
and item.tool_call_id
1479+
and (server_id := item.tool_name.split(':', 1)[1])
14791480
and (args := item.args_as_dict())
14801481
and (action := args.get('action'))
1481-
and (server_id := args.get('server_id'))
14821482
):
14831483
if action == 'list_tools':
14841484
mcp_list_tools_item = responses.response_input_item_param.McpListTools(
@@ -1525,7 +1525,7 @@ async def _map_messages( # noqa: C901
15251525
elif item.tool_name == ImageGenerationTool.kind:
15261526
# Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`.
15271527
pass
1528-
elif item.tool_name == MCPServerTool.kind: # pragma: no branch
1528+
elif item.tool_name.startswith(MCPServerTool.kind): # pragma: no branch
15291529
# MCP call result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
15301530
pass
15311531
elif isinstance(item, FilePart):
@@ -2257,15 +2257,16 @@ def _map_image_generation_tool_call(
22572257
def _map_mcp_list_tools(
22582258
item: responses.response_output_item.McpListTools, provider_name: str
22592259
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
2260+
tool_name = ':'.join([MCPServerTool.kind, item.server_label])
22602261
return (
22612262
BuiltinToolCallPart(
2262-
tool_name=MCPServerTool.kind,
2263+
tool_name=tool_name,
22632264
tool_call_id=item.id,
22642265
provider_name=provider_name,
2265-
args={'action': 'list_tools', 'server_id': item.server_label},
2266+
args={'action': 'list_tools'},
22662267
),
22672268
BuiltinToolReturnPart(
2268-
tool_name=MCPServerTool.kind,
2269+
tool_name=tool_name,
22692270
tool_call_id=item.id,
22702271
content=item.model_dump(mode='json', include={'tools', 'error'}),
22712272
provider_name=provider_name,
@@ -2276,20 +2277,20 @@ def _map_mcp_list_tools(
22762277
def _map_mcp_call(
22772278
item: responses.response_output_item.McpCall, provider_name: str
22782279
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
2280+
tool_name = ':'.join([MCPServerTool.kind, item.server_label])
22792281
return (
22802282
BuiltinToolCallPart(
2281-
tool_name=MCPServerTool.kind,
2283+
tool_name=tool_name,
22822284
tool_call_id=item.id,
22832285
args={
22842286
'action': 'call_tool',
2285-
'server_id': item.server_label,
22862287
'tool_name': item.name,
22872288
'tool_args': json.loads(item.arguments) if item.arguments else {},
22882289
},
22892290
provider_name=provider_name,
22902291
),
22912292
BuiltinToolReturnPart(
2292-
tool_name=MCPServerTool.kind,
2293+
tool_name=tool_name,
22932294
tool_call_id=item.id,
22942295
content={
22952296
'output': item.output,

tests/models/test_anthropic.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3098,10 +3098,9 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
30983098
provider_name='anthropic',
30993099
),
31003100
BuiltinToolCallPart(
3101-
tool_name='mcp_server',
3101+
tool_name='mcp_server:deepwiki',
31023102
args={
31033103
'action': 'call_tool',
3104-
'server_id': 'deepwiki',
31053104
'tool_name': 'ask_question',
31063105
'tool_args': {
31073106
'repoName': 'pydantic/pydantic-ai',
@@ -3112,7 +3111,7 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
31123111
provider_name='anthropic',
31133112
),
31143113
BuiltinToolReturnPart(
3115-
tool_name='mcp_server',
3114+
tool_name='mcp_server:deepwiki',
31163115
content={
31173116
'content': [
31183117
{
@@ -3181,10 +3180,9 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
31813180
provider_name='anthropic',
31823181
),
31833182
BuiltinToolCallPart(
3184-
tool_name='mcp_server',
3183+
tool_name='mcp_server:deepwiki',
31853184
args={
31863185
'action': 'call_tool',
3187-
'server_id': 'deepwiki',
31883186
'tool_name': 'ask_question',
31893187
'tool_args': {
31903188
'repoName': 'pydantic/pydantic',
@@ -3195,7 +3193,7 @@ async def test_anthropic_mcp_servers(allow_model_requests: None, anthropic_api_k
31953193
provider_name='anthropic',
31963194
),
31973195
BuiltinToolReturnPart(
3198-
tool_name='mcp_server',
3196+
tool_name='mcp_server:deepwiki',
31993197
content={
32003198
'content': [
32013199
{
@@ -3345,13 +3343,13 @@ async def test_anthropic_mcp_servers_stream(allow_model_requests: None, anthropi
33453343
provider_name='anthropic',
33463344
),
33473345
BuiltinToolCallPart(
3348-
tool_name='mcp_server',
3349-
args='{"action":"call_tool","server_id":"deepwiki","tool_name":"ask_question","tool_args":{"repoName": "pydantic/pydantic-ai", "question": "What is this repository about? What are its main features and purpose?"}}',
3346+
tool_name='mcp_server:deepwiki',
3347+
args='{"action":"call_tool","tool_name":"ask_question","tool_args":{"repoName": "pydantic/pydantic-ai", "question": "What is this repository about? What are its main features and purpose?"}}',
33503348
tool_call_id='mcptoolu_01FZmJ5UspaX5BB9uU339UT1',
33513349
provider_name='anthropic',
33523350
),
33533351
BuiltinToolReturnPart(
3354-
tool_name='mcp_server',
3352+
tool_name='mcp_server:deepwiki',
33553353
content={
33563354
'content': [
33573355
{
@@ -3407,15 +3405,15 @@ async def test_anthropic_mcp_servers_stream(allow_model_requests: None, anthropi
34073405
PartStartEvent(
34083406
index=1,
34093407
part=BuiltinToolCallPart(
3410-
tool_name='mcp_server',
3408+
tool_name='mcp_server:deepwiki',
34113409
tool_call_id='mcptoolu_01FZmJ5UspaX5BB9uU339UT1',
34123410
provider_name='anthropic',
34133411
),
34143412
),
34153413
PartDeltaEvent(
34163414
index=1,
34173415
delta=ToolCallPartDelta(
3418-
args_delta='{"action":"call_tool","server_id":"deepwiki","tool_name":"ask_question","tool_args":',
3416+
args_delta='{"action":"call_tool","tool_name":"ask_question","tool_args":',
34193417
tool_call_id='mcptoolu_01FZmJ5UspaX5BB9uU339UT1',
34203418
),
34213419
),
@@ -3489,7 +3487,7 @@ async def test_anthropic_mcp_servers_stream(allow_model_requests: None, anthropi
34893487
PartStartEvent(
34903488
index=2,
34913489
part=BuiltinToolReturnPart(
3492-
tool_name='mcp_server',
3490+
tool_name='mcp_server:deepwiki',
34933491
content={
34943492
'content': [
34953493
{

0 commit comments

Comments
 (0)