Skip to content

Commit a9ddabe

Browse files
committed
Make FastMCPToolset work with Temporal
1 parent 1df9ca6 commit a9ddabe

File tree

5 files changed

+1403
-27
lines changed

5 files changed

+1403
-27
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from dataclasses import dataclass
5+
from typing import Any, Literal
6+
7+
from pydantic import ConfigDict, with_config
8+
from temporalio import activity, workflow
9+
from temporalio.workflow import ActivityConfig
10+
from typing_extensions import Self
11+
12+
from pydantic_ai import ToolsetTool
13+
from pydantic_ai.exceptions import UserError
14+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
15+
from pydantic_ai.toolsets.fastmcp import FastMCPToolset
16+
17+
from ._run_context import TemporalRunContext
18+
from ._toolset import (
19+
CallToolParams,
20+
CallToolResult,
21+
TemporalWrapperToolset,
22+
)
23+
24+
25+
@dataclass
26+
@with_config(ConfigDict(arbitrary_types_allowed=True))
27+
class _GetToolsParams:
28+
serialized_run_context: Any
29+
30+
31+
class TemporalFastMCPToolset(TemporalWrapperToolset[AgentDepsT]):
32+
def __init__(
33+
self,
34+
server: FastMCPToolset[AgentDepsT],
35+
*,
36+
activity_name_prefix: str,
37+
activity_config: ActivityConfig,
38+
tool_activity_config: dict[str, ActivityConfig | Literal[False]],
39+
deps_type: type[AgentDepsT],
40+
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
41+
):
42+
super().__init__(server)
43+
self.activity_config = activity_config
44+
45+
self.tool_activity_config: dict[str, ActivityConfig] = {}
46+
for tool_name, tool_config in tool_activity_config.items():
47+
if tool_config is False:
48+
raise UserError(
49+
f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
50+
'but MCP tools require the use of IO and so cannot be run outside of an activity.'
51+
)
52+
self.tool_activity_config[tool_name] = tool_config
53+
54+
self.run_context_type = run_context_type
55+
56+
async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
57+
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
58+
tools = await self.wrapped.get_tools(run_context)
59+
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
60+
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
61+
return {name: tool.tool_def for name, tool in tools.items()}
62+
63+
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
64+
get_tools_activity.__annotations__['deps'] = deps_type
65+
66+
self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')(
67+
get_tools_activity
68+
)
69+
70+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
71+
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
72+
assert isinstance(params.tool_def, ToolDefinition)
73+
return await self._wrap_call_tool_result(
74+
self.wrapped.call_tool(
75+
params.name,
76+
params.tool_args,
77+
run_context,
78+
self.tool_for_tool_def(params.tool_def),
79+
)
80+
)
81+
82+
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
83+
call_tool_activity.__annotations__['deps'] = deps_type
84+
85+
self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')(
86+
call_tool_activity
87+
)
88+
89+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
90+
assert isinstance(self.wrapped, FastMCPToolset)
91+
return self.wrapped.tool_for_tool_def(tool_def)
92+
93+
@property
94+
def temporal_activities(self) -> list[Callable[..., Any]]:
95+
return [self.get_tools_activity, self.call_tool_activity]
96+
97+
async def __aenter__(self) -> Self:
98+
# The wrapped MCPServer enters itself around listing and calling tools
99+
# so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
100+
return self
101+
102+
async def __aexit__(self, *args: Any) -> bool | None:
103+
return None
104+
105+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
106+
if not workflow.in_workflow():
107+
return await super().get_tools(ctx)
108+
109+
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
110+
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
111+
activity=self.get_tools_activity,
112+
args=[
113+
_GetToolsParams(serialized_run_context=serialized_run_context),
114+
ctx.deps,
115+
],
116+
**self.activity_config,
117+
)
118+
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}
119+
120+
async def call_tool(
121+
self,
122+
name: str,
123+
tool_args: dict[str, Any],
124+
ctx: RunContext[AgentDepsT],
125+
tool: ToolsetTool[AgentDepsT],
126+
) -> CallToolResult:
127+
if not workflow.in_workflow():
128+
return await super().call_tool(name, tool_args, ctx, tool)
129+
130+
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
131+
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
132+
return self._unwrap_call_tool_result(
133+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
134+
activity=self.call_tool_activity,
135+
args=[
136+
CallToolParams(
137+
name=name,
138+
tool_args=tool_args,
139+
serialized_run_context=serialized_run_context,
140+
tool_def=tool.tool_def,
141+
),
142+
ctx.deps,
143+
],
144+
**tool_activity_config,
145+
)
146+
)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,21 @@ def temporalize_toolset(
142142
run_context_type=run_context_type,
143143
)
144144

145+
try:
146+
from pydantic_ai.toolsets.fastmcp import FastMCPToolset
147+
148+
from ._fastmcp_toolset import TemporalFastMCPToolset
149+
except ImportError:
150+
pass
151+
else:
152+
if isinstance(toolset, FastMCPToolset):
153+
return TemporalFastMCPToolset(
154+
toolset,
155+
activity_name_prefix=activity_name_prefix,
156+
activity_config=activity_config,
157+
tool_activity_config=tool_activity_config,
158+
deps_type=deps_type,
159+
run_context_type=run_context_type,
160+
)
161+
145162
return toolset

pydantic_ai_slim/pydantic_ai/toolsets/fastmcp.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
ResourceLink,
3333
TextContent,
3434
TextResourceContents,
35-
Tool as MCPTool,
3635
)
3736

3837
from pydantic_ai.mcp import TOOL_SCHEMA_VALIDATOR
@@ -131,11 +130,20 @@ async def __aexit__(self, *args: Any) -> bool | None:
131130

132131
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
133132
async with self:
134-
mcp_tools: list[MCPTool] = await self.client.list_tools()
135-
136133
return {
137-
tool.name: _convert_mcp_tool_to_toolset_tool(toolset=self, mcp_tool=tool, retries=self.max_retries)
138-
for tool in mcp_tools
134+
mcp_tool.name: self.tool_for_tool_def(
135+
ToolDefinition(
136+
name=mcp_tool.name,
137+
description=mcp_tool.description,
138+
parameters_json_schema=mcp_tool.inputSchema,
139+
metadata={
140+
'meta': mcp_tool.meta,
141+
'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None,
142+
'output_schema': mcp_tool.outputSchema or None,
143+
},
144+
)
145+
)
146+
for mcp_tool in await self.client.list_tools()
139147
}
140148

141149
async def call_tool(
@@ -157,28 +165,13 @@ async def call_tool(
157165
# Otherwise, return the content
158166
return _map_fastmcp_tool_results(parts=call_tool_result.content)
159167

160-
161-
def _convert_mcp_tool_to_toolset_tool(
162-
toolset: FastMCPToolset[AgentDepsT],
163-
mcp_tool: MCPTool,
164-
retries: int,
165-
) -> ToolsetTool[AgentDepsT]:
166-
"""Convert an MCP tool to a toolset tool."""
167-
return ToolsetTool[AgentDepsT](
168-
tool_def=ToolDefinition(
169-
name=mcp_tool.name,
170-
description=mcp_tool.description,
171-
parameters_json_schema=mcp_tool.inputSchema,
172-
metadata={
173-
'meta': mcp_tool.meta,
174-
'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None,
175-
'output_schema': mcp_tool.outputSchema or None,
176-
},
177-
),
178-
toolset=toolset,
179-
max_retries=retries,
180-
args_validator=TOOL_SCHEMA_VALIDATOR,
181-
)
168+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
169+
return ToolsetTool[AgentDepsT](
170+
tool_def=tool_def,
171+
toolset=self,
172+
max_retries=self.max_retries,
173+
args_validator=TOOL_SCHEMA_VALIDATOR,
174+
)
182175

183176

184177
def _map_fastmcp_tool_results(parts: list[ContentBlock]) -> list[FastMCPToolResult] | FastMCPToolResult:

0 commit comments

Comments
 (0)