Skip to content

Commit cffc38f

Browse files
committed
Add FastMCP Toolset w/o tests
1 parent e30c444 commit cffc38f

File tree

2 files changed

+259
-0
lines changed

2 files changed

+259
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import contextlib
5+
from asyncio import Lock
6+
from contextlib import AsyncExitStack
7+
from enum import Enum
8+
from typing import TYPE_CHECKING, Any, Self
9+
10+
import pydantic_core
11+
from mcp.types import (
12+
AudioContent,
13+
ContentBlock,
14+
EmbeddedResource,
15+
ImageContent,
16+
TextContent,
17+
TextResourceContents,
18+
Tool as MCPTool,
19+
)
20+
21+
from pydantic_ai.exceptions import ModelRetry
22+
from pydantic_ai.mcp import TOOL_SCHEMA_VALIDATOR, messages
23+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
24+
from pydantic_ai.toolsets import AbstractToolset
25+
from pydantic_ai.toolsets.abstract import ToolsetTool
26+
27+
try:
28+
from fastmcp.client import Client
29+
from fastmcp.client.transports import MCPConfigTransport
30+
from fastmcp.exceptions import ToolError
31+
from fastmcp.mcp_config import MCPConfig
32+
from fastmcp.server.server import FastMCP
33+
except ImportError as _import_error:
34+
raise ImportError(
35+
'Please install the `fastmcp` package to use the FastMCP server, '
36+
'you can use the `fastmcp` optional group — `pip install "pydantic-ai-slim[fastmcp]"`'
37+
) from _import_error
38+
39+
40+
if TYPE_CHECKING:
41+
from fastmcp import FastMCP
42+
from fastmcp.client.client import CallToolResult
43+
from fastmcp.client.transports import FastMCPTransport
44+
from fastmcp.mcp_config import MCPServerTypes
45+
46+
47+
FastMCPToolResult = messages.BinaryContent | dict[str, Any] | str | None
48+
49+
FastMCPToolResults = list[FastMCPToolResult] | FastMCPToolResult
50+
51+
52+
class ToolErrorBehavior(str, Enum):
53+
"""The behavior to take when a tool error occurs."""
54+
55+
MODEL_RETRY = 'model-retry'
56+
"""Raise a `ModelRetry` containing the tool error message."""
57+
58+
ERROR = 'raise'
59+
"""Raise the tool error as an exception."""
60+
61+
62+
class FastMCPToolset(AbstractToolset[AgentDepsT]):
63+
"""A toolset that uses a FastMCP client as the underlying toolset."""
64+
65+
_fastmcp_client: Client[Any] | None = None
66+
_tool_error_behavior: ToolErrorBehavior
67+
68+
_tool_retries: int
69+
70+
_enter_lock: Lock
71+
_running_count: int
72+
_exit_stack: AsyncExitStack | None
73+
74+
def __init__(
75+
self, fastmcp_client: Client[Any], tool_retries: int = 2, tool_error_behavior: ToolErrorBehavior | None = None
76+
):
77+
self._tool_retries = tool_retries
78+
self._fastmcp_client = fastmcp_client
79+
self._enter_lock = Lock()
80+
self._running_count = 0
81+
82+
self._tool_error_behavior = tool_error_behavior or ToolErrorBehavior.ERROR
83+
84+
super().__init__()
85+
86+
@property
87+
def id(self) -> str | None:
88+
return None
89+
90+
async def __aenter__(self) -> Self:
91+
async with self._enter_lock:
92+
if self._running_count == 0 and self._fastmcp_client:
93+
self._exit_stack = AsyncExitStack()
94+
await self._exit_stack.enter_async_context(self._fastmcp_client)
95+
self._running_count += 1
96+
97+
return self
98+
99+
async def __aexit__(self, *args: Any) -> bool | None:
100+
async with self._enter_lock:
101+
self._running_count -= 1
102+
if self._running_count == 0 and self._exit_stack:
103+
await self._exit_stack.aclose()
104+
self._exit_stack = None
105+
106+
return None
107+
108+
@property
109+
def fastmcp_client(self) -> Client[FastMCPTransport]:
110+
if not self._fastmcp_client:
111+
msg = 'FastMCP client not initialized'
112+
raise RuntimeError(msg)
113+
114+
return self._fastmcp_client
115+
116+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
117+
mcp_tools: list[MCPTool] = await self.fastmcp_client.list_tools()
118+
119+
return {
120+
tool.name: convert_mcp_tool_to_toolset_tool(toolset=self, mcp_tool=tool, retries=self._tool_retries)
121+
for tool in mcp_tools
122+
}
123+
124+
async def call_tool(
125+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT]
126+
) -> Any:
127+
try:
128+
call_tool_result: CallToolResult = await self.fastmcp_client.call_tool(name=name, arguments=tool_args)
129+
except ToolError as e:
130+
if self._tool_error_behavior == ToolErrorBehavior.MODEL_RETRY:
131+
raise ModelRetry(message=str(object=e)) from e
132+
else:
133+
raise e
134+
135+
# We don't use call_tool_result.data at the moment because it requires the json schema to be translatable
136+
# back into pydantic models otherwise it will be missing data.
137+
138+
return call_tool_result.structured_content or _map_fastmcp_tool_results(parts=call_tool_result.content)
139+
140+
@classmethod
141+
def from_fastmcp_server(
142+
cls, fastmcp_server: FastMCP[Any], tool_error_behavior: ToolErrorBehavior | None = None
143+
) -> Self:
144+
"""Build a FastMCPToolset from a FastMCP server.
145+
146+
Example:
147+
```python
148+
fastmcp_server = FastMCP('my_server')
149+
@fastmcp_server.tool()
150+
async def my_tool(a: int, b: int) -> int:
151+
return a + b
152+
153+
toolset = FastMCPToolset.from_fastmcp_server(fastmcp_server=fastmcp_server)
154+
```
155+
"""
156+
fastmcp_client: Client[FastMCPTransport] = Client[FastMCPTransport](transport=fastmcp_server)
157+
return cls(fastmcp_client=fastmcp_client, tool_retries=2, tool_error_behavior=tool_error_behavior)
158+
159+
@classmethod
160+
def from_mcp_server(
161+
cls,
162+
name: str,
163+
mcp_server: MCPServerTypes | dict[str, Any],
164+
tool_error_behavior: ToolErrorBehavior | None = None,
165+
) -> Self:
166+
"""Build a FastMCPToolset from an individual MCP server configuration.
167+
168+
Example:
169+
```python
170+
cls.from_mcp_server(name='my_server', mcp_server={
171+
'cmd': 'uvx',
172+
'args': [
173+
"time-server-mcp",
174+
]
175+
})
176+
```
177+
"""
178+
mcp_config: MCPConfig = MCPConfig.from_dict(config={name: mcp_server})
179+
180+
return cls.from_mcp_config(mcp_config=mcp_config, tool_error_behavior=tool_error_behavior)
181+
182+
@classmethod
183+
def from_mcp_config(
184+
cls, mcp_config: MCPConfig | dict[str, Any], tool_error_behavior: ToolErrorBehavior | None = None
185+
) -> Self:
186+
"""Build a FastMCPToolset from an MCP json-derived / dictionary configuration object.
187+
188+
Example:
189+
```python
190+
cls.from_mcp_config(mcp_config={
191+
'mcpServers': {
192+
'my_server': {
193+
'cmd': 'uvx',
194+
'args': [
195+
"time-server-mcp",
196+
]
197+
}
198+
}
199+
})
200+
```
201+
"""
202+
fastmcp_client: Client[MCPConfigTransport] = Client[MCPConfigTransport](transport=mcp_config)
203+
return cls(fastmcp_client=fastmcp_client, tool_retries=2, tool_error_behavior=tool_error_behavior)
204+
205+
206+
def convert_mcp_tool_to_toolset_tool(
207+
toolset: FastMCPToolset[AgentDepsT],
208+
mcp_tool: MCPTool,
209+
retries: int,
210+
) -> ToolsetTool[AgentDepsT]:
211+
"""Convert an MCP tool to a toolset tool."""
212+
return ToolsetTool[AgentDepsT](
213+
tool_def=ToolDefinition(
214+
name=mcp_tool.name,
215+
description=mcp_tool.description,
216+
parameters_json_schema=mcp_tool.inputSchema,
217+
),
218+
toolset=toolset,
219+
max_retries=retries,
220+
args_validator=TOOL_SCHEMA_VALIDATOR,
221+
)
222+
223+
224+
def _map_fastmcp_tool_results(parts: list[ContentBlock]) -> list[FastMCPToolResult]:
225+
"""Map FastMCP tool results to toolset tool results."""
226+
return [_map_fastmcp_tool_result(part) for part in parts]
227+
228+
229+
def _map_fastmcp_tool_result(part: ContentBlock) -> FastMCPToolResult:
230+
if isinstance(part, TextContent):
231+
text = part.text
232+
if text.startswith(('[', '{')):
233+
with contextlib.suppress(ValueError):
234+
result: Any = pydantic_core.from_json(text)
235+
if isinstance(result, dict | list):
236+
return result # pyright: ignore[reportUnknownVariableType, reportReturnType]
237+
return text
238+
239+
if isinstance(part, ImageContent):
240+
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
241+
242+
if isinstance(part, AudioContent):
243+
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
244+
245+
if isinstance(part, EmbeddedResource):
246+
resource = part.resource
247+
if isinstance(resource, TextResourceContents):
248+
return resource.text
249+
250+
# BlobResourceContents
251+
return messages.BinaryContent(
252+
data=base64.b64decode(resource.blob),
253+
media_type=resource.mimeType or 'application/octet-stream',
254+
)
255+
256+
msg = f'Unsupported/Unknown content block type: {type(part)}'
257+
raise ValueError(msg)

pydantic_ai_slim/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ cli = [
8989
]
9090
# MCP
9191
mcp = ["mcp>=1.12.3"]
92+
# FastMCP
93+
fastmcp = ["fastmcp>=2.12.0"]
9294
# Evals
9395
evals = ["pydantic-evals=={{ version }}"]
9496
# A2A

0 commit comments

Comments
 (0)