Skip to content

Commit 2606315

Browse files
committed
Adding streaming support and test coverage
1 parent bd844f3 commit 2606315

File tree

6 files changed

+626
-271
lines changed

6 files changed

+626
-271
lines changed

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,16 +243,14 @@ class MCPServerTool(AbstractBuiltinTool):
243243
Supported by:
244244
245245
* OpenAI Responses
246+
* Anthropic
246247
"""
247248

248249
kind: str = 'mcp_server'
249250

250251
server_label: str
251252
"""The label of the MCP server to use."""
252253

253-
require_approval: Literal['never', 'always'] | None = None
254-
"""Whether to require approval before using the MCP server."""
255-
256254
server_url: str
257255
"""The URL of the MCP server to use."""
258256

@@ -272,4 +270,9 @@ class MCPServerTool(AbstractBuiltinTool):
272270
"""
273271

274272
connector_id: str | None = None
275-
"""The ID of the connector to use."""
273+
"""The ID of the connector to use.
274+
275+
Supported by:
276+
277+
* OpenAI Responses
278+
"""

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1414
from .._run_context import RunContext
1515
from .._utils import guard_tool_call_id as _guard_tool_call_id
16-
from ..builtin_tools import CodeExecutionTool, MemoryTool, WebSearchTool
16+
from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebSearchTool
1717
from ..exceptions import UserError
1818
from ..messages import (
1919
BinaryContent,
@@ -82,6 +82,8 @@
8282
BetaRawMessageStreamEvent,
8383
BetaRedactedThinkingBlock,
8484
BetaRedactedThinkingBlockParam,
85+
BetaRequestMCPServerToolConfigurationParam,
86+
BetaRequestMCPServerURLDefinitionParam,
8587
BetaServerToolUseBlock,
8688
BetaServerToolUseBlockParam,
8789
BetaSignatureDelta,
@@ -265,6 +267,7 @@ async def _messages_create(
265267
# standalone function to make it easier to override
266268
tools = self._get_tools(model_request_parameters)
267269
tools, beta_features = self._add_builtin_tools(tools, model_request_parameters)
270+
mcp_servers = self._get_mcp_servers(model_request_parameters)
268271

269272
tool_choice: BetaToolChoiceParam | None
270273

@@ -300,6 +303,7 @@ async def _messages_create(
300303
model=self._model_name,
301304
tools=tools or OMIT,
302305
tool_choice=tool_choice or OMIT,
306+
mcp_servers=mcp_servers or OMIT,
303307
stream=stream,
304308
thinking=model_settings.get('anthropic_thinking', OMIT),
305309
stop_sequences=model_settings.get('stop_sequences', OMIT),
@@ -407,12 +411,40 @@ def _add_builtin_tools(
407411
tools = [tool for tool in tools if tool['name'] != 'memory']
408412
tools.append(BetaMemoryTool20250818Param(name='memory', type='memory_20250818'))
409413
beta_features.append('context-management-2025-06-27')
414+
elif isinstance(tool, MCPServerTool):
415+
# Anthropic MCP servers are a separate parameter in the API call
416+
pass
410417
else: # pragma: no cover
411418
raise UserError(
412419
f'`{tool.__class__.__name__}` is not supported by `AnthropicModel`. If it should be, please file an issue.'
413420
)
414421
return tools, beta_features
415422

423+
def _get_mcp_servers(
424+
self, model_request_parameters: ModelRequestParameters
425+
) -> list[BetaRequestMCPServerURLDefinitionParam]:
426+
mcp_servers: list[BetaRequestMCPServerURLDefinitionParam] = []
427+
for tool in model_request_parameters.builtin_tools:
428+
if isinstance(tool, MCPServerTool):
429+
tool_configuration = (
430+
BetaRequestMCPServerToolConfigurationParam(
431+
enabled=True,
432+
allowed_tools=tool.allowed_tools,
433+
)
434+
if tool.allowed_tools
435+
else None
436+
)
437+
mcp_servers.append(
438+
BetaRequestMCPServerURLDefinitionParam(
439+
type='url',
440+
name=tool.server_label,
441+
url=tool.server_url,
442+
authorization_token=tool.authorization,
443+
tool_configuration=tool_configuration,
444+
)
445+
)
446+
return mcp_servers
447+
416448
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901
417449
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
418450
system_prompt_parts: list[str] = []

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 119 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,13 +1032,16 @@ def _process_response( # noqa: C901
10321032
elif isinstance(item, responses.ResponseFileSearchToolCall): # pragma: no cover
10331033
# Pydantic AI doesn't yet support the FileSearch built-in tool
10341034
pass
1035-
elif isinstance( # pragma: no cover
1036-
item,
1037-
responses.response_output_item.McpCall
1038-
| responses.response_output_item.McpListTools
1039-
| responses.response_output_item.McpApprovalRequest,
1040-
):
1041-
# Pydantic AI supports MCP natively
1035+
elif isinstance(item, responses.response_output_item.McpCall):
1036+
call_part, return_part = _map_mcp_call(item, self.system)
1037+
items.append(call_part)
1038+
items.append(return_part)
1039+
elif isinstance(item, responses.response_output_item.McpListTools):
1040+
call_part, return_part = _map_mcp_list_tools(item, self.system)
1041+
items.append(call_part)
1042+
items.append(return_part)
1043+
elif isinstance(item, responses.response_output_item.McpApprovalRequest):
1044+
# Pydantic AI doesn't yet support McpApprovalRequest (explicit tool usage approval)
10421045
pass
10431046

10441047
finish_reason: FinishReason | None = None
@@ -1235,7 +1238,7 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -
12351238
server_description=tool.server_description,
12361239
allowed_tools=tool.allowed_tools,
12371240
authorization=tool.authorization,
1238-
require_approval=tool.require_approval,
1241+
require_approval='never',
12391242
headers=tool.headers,
12401243
)
12411244
)
@@ -1760,7 +1763,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
17601763
elif isinstance(chunk.item, responses.response_output_item.ImageGenerationCall):
17611764
call_part, _, _ = _map_image_generation_tool_call(chunk.item, self.provider_name)
17621765
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
1763-
1766+
elif isinstance(chunk.item, responses.response_output_item.McpCall):
1767+
call_part, _ = _map_mcp_call(chunk.item, self.provider_name)
1768+
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
1769+
elif isinstance(chunk.item, responses.response_output_item.McpListTools):
1770+
call_part, _ = _map_mcp_list_tools(chunk.item, self.provider_name)
1771+
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
17641772
else:
17651773
warnings.warn( # pragma: no cover
17661774
f'Handling of this item type is not yet implemented. Please report on our GitHub: {chunk}',
@@ -1801,6 +1809,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
18011809
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-file', part=file_part)
18021810
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
18031811

1812+
elif isinstance(chunk.item, responses.response_output_item.McpCall):
1813+
call_part, return_part = _map_mcp_call(chunk.item, self.provider_name)
1814+
maybe_event = self._parts_manager.handle_tool_call_delta(
1815+
vendor_part_id=f'{chunk.item.id}-call',
1816+
args=call_part.args,
1817+
)
1818+
if maybe_event is not None: # pragma: no branch
1819+
yield maybe_event
1820+
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
1821+
elif isinstance(chunk.item, responses.response_output_item.McpListTools):
1822+
call_part, return_part = _map_mcp_list_tools(chunk.item, self.provider_name)
1823+
maybe_event = self._parts_manager.handle_tool_call_delta(
1824+
vendor_part_id=f'{chunk.item.id}-call',
1825+
args=call_part.args,
1826+
)
1827+
if maybe_event is not None: # pragma: no branch
1828+
yield maybe_event
1829+
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
1830+
18041831
elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
18051832
yield self._parts_manager.handle_thinking_delta(
18061833
vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
@@ -1895,6 +1922,33 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
18951922
)
18961923
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item_id}-file', part=file_part)
18971924

1925+
elif isinstance(chunk, responses.ResponseMcpCallArgumentsDeltaEvent):
1926+
maybe_event = self._parts_manager.handle_text_delta(
1927+
vendor_part_id=chunk.item_id, content=chunk.delta, id=chunk.item_id
1928+
)
1929+
if maybe_event is not None: # pragma: no branch
1930+
yield maybe_event
1931+
elif isinstance(chunk, responses.ResponseMcpCallArgumentsDoneEvent):
1932+
pass # there's nothing we need to do here
1933+
1934+
elif isinstance(chunk, responses.ResponseMcpListToolsInProgressEvent):
1935+
pass # there's nothing we need to do here
1936+
1937+
elif isinstance(chunk, responses.ResponseMcpListToolsCompletedEvent):
1938+
pass # there's nothing we need to do here
1939+
1940+
elif isinstance(chunk, responses.ResponseMcpListToolsFailedEvent):
1941+
pass # there's nothing we need to do here
1942+
1943+
elif isinstance(chunk, responses.ResponseMcpCallInProgressEvent):
1944+
pass # there's nothing we need to do here
1945+
1946+
elif isinstance(chunk, responses.ResponseMcpCallFailedEvent):
1947+
pass # there's nothing we need to do here
1948+
1949+
elif isinstance(chunk, responses.ResponseMcpCallCompletedEvent):
1950+
pass # there's nothing we need to do here
1951+
18981952
else: # pragma: no cover
18991953
warnings.warn(
19001954
f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
@@ -2106,3 +2160,59 @@ def _map_image_generation_tool_call(
21062160
),
21072161
file_part,
21082162
)
2163+
2164+
2165+
def _map_mcp_list_tools(
2166+
item: responses.response_output_item.McpListTools, provider_name: str
2167+
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
2168+
result = {
2169+
'server_label': item.server_label,
2170+
'tools': [
2171+
{
2172+
'name': tool.name,
2173+
'description': tool.description,
2174+
'input_schema': tool.input_schema,
2175+
}
2176+
for tool in item.tools
2177+
],
2178+
}
2179+
2180+
return (
2181+
BuiltinToolCallPart(
2182+
tool_name=MCPServerTool.kind,
2183+
tool_call_id=item.id,
2184+
provider_name=provider_name,
2185+
),
2186+
BuiltinToolReturnPart(
2187+
tool_name=MCPServerTool.kind,
2188+
tool_call_id=item.id,
2189+
content=result,
2190+
provider_name=provider_name,
2191+
),
2192+
)
2193+
2194+
2195+
def _map_mcp_call(
2196+
item: responses.response_output_item.McpCall, provider_name: str
2197+
) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
2198+
result = {
2199+
'name': item.name,
2200+
'server_label': item.server_label,
2201+
'arguments': item.arguments,
2202+
'error': item.error,
2203+
'output': item.output,
2204+
}
2205+
2206+
return (
2207+
BuiltinToolCallPart(
2208+
tool_name=MCPServerTool.kind,
2209+
tool_call_id=item.id,
2210+
provider_name=provider_name,
2211+
),
2212+
BuiltinToolReturnPart(
2213+
tool_name=MCPServerTool.kind,
2214+
tool_call_id=item.id,
2215+
content=result,
2216+
provider_name=provider_name,
2217+
),
2218+
)

0 commit comments

Comments
 (0)