Skip to content

Commit 4f0f5e9

Browse files
committed
Addressing comments
1 parent f0e0c38 commit 4f0f5e9

File tree

6 files changed

+1195
-168
lines changed

6 files changed

+1195
-168
lines changed

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ class MCPServerTool(AbstractBuiltinTool):
250250
"""
251251

252252
kind: str = 'mcp_server'
253+
list_tools_kind: str = 'mcp_list_tools'
254+
call_kind: str = 'mcp_call'
253255

254256
label: str
255257
"""The label of the MCP server to use."""

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
544544
input=response_part.args_as_dict(),
545545
)
546546
assistant_content_params.append(server_tool_use_block_param)
547-
elif response_part.tool_name == MCPServerTool.kind: # pragma: no branch
547+
elif response_part.tool_name == MCPServerTool.kind: # pragma: no cover
548548
mcp_tool_use_block_param = BetaMCPToolUseBlockParam(
549549
id=tool_use_id,
550550
type='mcp_tool_use',
@@ -584,7 +584,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be
584584
),
585585
)
586586
)
587-
elif response_part.tool_name in ( # pragma: no branch
587+
elif response_part.tool_name in ( # pragma: no cover
588588
MCPServerTool.kind,
589589
'mcp_tool_result', # Backward compatibility
590590
) and isinstance(response_part.content, str | list):

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,9 +1242,9 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -
12421242
require_approval='never',
12431243
headers=tool.headers,
12441244
)
1245-
if tool.url:
1245+
if tool.url: # pragma: no cover
12461246
mcp_tool['server_url'] = tool.url
1247-
if tool.connector_id:
1247+
elif tool.connector_id: # pragma: no cover
12481248
mcp_tool['connector_id'] = tool.connector_id
12491249
tools.append(mcp_tool)
12501250
elif isinstance(tool, ImageGenerationTool): # pragma: no branch
@@ -1428,32 +1428,36 @@ async def _map_messages( # noqa: C901
14281428
},
14291429
)
14301430
openai_messages.append(image_generation_item)
1431-
elif (
1432-
item.tool_name == MCPServerTool.kind
1431+
elif ( # pragma: no cover
1432+
item.tool_name == MCPServerTool.list_tools_kind
14331433
and item.tool_call_id
14341434
and (args := item.args_as_dict())
1435-
and (server_label := args.get('server_label'))
1436-
): # pragma: no branch
1437-
if tools := args.get('tools'):
1438-
mcp_list_tools_item = responses.response_input_item_param.McpListTools(
1439-
id=item.tool_call_id,
1440-
tools=cast(list[responses.response_input_item_param.McpListToolsTool], tools),
1441-
server_label=server_label,
1442-
error=args.get('error'),
1443-
type='mcp_list_tools',
1444-
)
1445-
openai_messages.append(mcp_list_tools_item)
1446-
elif (arguments := args.get('arguments')) and (name := args.get('name')):
1447-
mcp_call_item = responses.response_input_item_param.McpCall(
1448-
id=item.tool_call_id,
1449-
name=name,
1450-
arguments=arguments,
1451-
server_label=server_label,
1452-
error=cast(str | None, args.get('error')),
1453-
output=cast(str | None, args.get('output')),
1454-
type='mcp_call',
1455-
)
1456-
openai_messages.append(mcp_call_item)
1435+
):
1436+
mcp_list_tools_item = responses.response_input_item_param.McpListTools(
1437+
id=item.tool_call_id,
1438+
tools=cast(
1439+
list[responses.response_input_item_param.McpListToolsTool], args.get('tools')
1440+
),
1441+
server_label=cast(str, args.get('server_label')),
1442+
error=args.get('error'),
1443+
type='mcp_list_tools',
1444+
)
1445+
openai_messages.append(mcp_list_tools_item)
1446+
elif ( # pragma: no cover
1447+
item.tool_name == MCPServerTool.call_kind
1448+
and item.tool_call_id
1449+
and (args := item.args_as_dict())
1450+
):
1451+
mcp_call_item = responses.response_input_item_param.McpCall(
1452+
id=item.tool_call_id,
1453+
name=cast(str, args.get('name')),
1454+
arguments=cast(str, args.get('arguments')),
1455+
server_label=cast(str, args.get('server_label')),
1456+
error=cast(str | None, args.get('error')),
1457+
output=cast(str | None, args.get('output')),
1458+
type='mcp_call',
1459+
)
1460+
openai_messages.append(mcp_call_item)
14571461

14581462
elif isinstance(item, BuiltinToolReturnPart):
14591463
if item.provider_name == self.system and send_item_ids:
@@ -1473,11 +1477,14 @@ async def _map_messages( # noqa: C901
14731477
and (status := content.get('status'))
14741478
):
14751479
web_search_item['status'] = status
1476-
elif item.tool_name == ImageGenerationTool.kind: # pragma: no branch
1480+
elif item.tool_name == ImageGenerationTool.kind: # pragma: no cover
14771481
# Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`.
14781482
pass
1479-
elif item.tool_name == MCPServerTool.kind: # pragma: no branch
1480-
# MCP result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
1483+
elif item.tool_name == MCPServerTool.list_tools_kind: # pragma: no cover
1484+
# MCP list result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
1485+
pass
1486+
elif item.tool_name == MCPServerTool.call_kind: # pragma: no cover
1487+
# MCP call result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
14811488
pass
14821489
elif isinstance(item, FilePart):
14831490
# This was generated by the `ImageGenerationTool` or `CodeExecutionTool`,
@@ -2202,13 +2209,13 @@ def _map_mcp_list_tools(
22022209

22032210
return (
22042211
BuiltinToolCallPart(
2205-
tool_name=MCPServerTool.kind,
2212+
tool_name=MCPServerTool.list_tools_kind,
22062213
tool_call_id=item.id,
22072214
args=item_serialized,
22082215
provider_name=provider_name,
22092216
),
22102217
BuiltinToolReturnPart(
2211-
tool_name=MCPServerTool.kind,
2218+
tool_name=MCPServerTool.list_tools_kind,
22122219
tool_call_id=item.id,
22132220
content=item_serialized,
22142221
provider_name=provider_name,
@@ -2226,13 +2233,13 @@ def _map_mcp_call(
22262233

22272234
return (
22282235
BuiltinToolCallPart(
2229-
tool_name=MCPServerTool.kind,
2236+
tool_name=MCPServerTool.call_kind,
22302237
tool_call_id=item.id,
22312238
args=item_serialized,
22322239
provider_name=provider_name,
22332240
),
22342241
BuiltinToolReturnPart(
2235-
tool_name=MCPServerTool.kind,
2242+
tool_name=MCPServerTool.call_kind,
22362243
tool_call_id=item.id,
22372244
content=item_serialized,
22382245
provider_name=provider_name,

0 commit comments

Comments
 (0)