Skip to content

Commit 829e03b

Browse files
authored
Make FastMCPToolset work with Temporal (#3413)
1 parent 0e45d0d commit 829e03b

File tree

8 files changed

+1454
-147
lines changed

8 files changed

+1454
-147
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
7575
'httpx',
7676
'anyio',
7777
'httpcore',
78+
# Used by fastmcp via py-key-value-aio
79+
'beartype',
7880
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
7981
'attrs',
8082
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
from typing import Literal
4+
5+
from temporalio.workflow import ActivityConfig
6+
7+
from pydantic_ai import ToolsetTool
8+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
9+
from pydantic_ai.toolsets.fastmcp import FastMCPToolset
10+
11+
from ._mcp import TemporalMCPToolset
12+
from ._run_context import TemporalRunContext
13+
14+
15+
class TemporalFastMCPToolset(TemporalMCPToolset[AgentDepsT]):
16+
def __init__(
17+
self,
18+
toolset: FastMCPToolset[AgentDepsT],
19+
*,
20+
activity_name_prefix: str,
21+
activity_config: ActivityConfig,
22+
tool_activity_config: dict[str, ActivityConfig | Literal[False]],
23+
deps_type: type[AgentDepsT],
24+
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
25+
):
26+
super().__init__(
27+
toolset,
28+
activity_name_prefix=activity_name_prefix,
29+
activity_config=activity_config,
30+
tool_activity_config=tool_activity_config,
31+
deps_type=deps_type,
32+
run_context_type=run_context_type,
33+
)
34+
35+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
36+
assert isinstance(self.wrapped, FastMCPToolset)
37+
return self.wrapped.tool_for_tool_def(tool_def)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from collections.abc import Callable
5+
from dataclasses import dataclass
6+
from typing import Any, Literal
7+
8+
from pydantic import ConfigDict, with_config
9+
from temporalio import activity, workflow
10+
from temporalio.workflow import ActivityConfig
11+
from typing_extensions import Self
12+
13+
from pydantic_ai import ToolsetTool
14+
from pydantic_ai.exceptions import UserError
15+
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
16+
from pydantic_ai.toolsets import AbstractToolset
17+
18+
from ._run_context import TemporalRunContext
19+
from ._toolset import (
20+
CallToolParams,
21+
CallToolResult,
22+
TemporalWrapperToolset,
23+
)
24+
25+
26+
@dataclass
27+
@with_config(ConfigDict(arbitrary_types_allowed=True))
28+
class _GetToolsParams:
29+
serialized_run_context: Any
30+
31+
32+
class TemporalMCPToolset(TemporalWrapperToolset[AgentDepsT], ABC):
33+
def __init__(
34+
self,
35+
toolset: AbstractToolset[AgentDepsT],
36+
*,
37+
activity_name_prefix: str,
38+
activity_config: ActivityConfig,
39+
tool_activity_config: dict[str, ActivityConfig | Literal[False]],
40+
deps_type: type[AgentDepsT],
41+
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
42+
):
43+
super().__init__(toolset)
44+
self.activity_config = activity_config
45+
46+
self.tool_activity_config: dict[str, ActivityConfig] = {}
47+
for tool_name, tool_config in tool_activity_config.items():
48+
if tool_config is False:
49+
raise UserError(
50+
f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
51+
'but MCP tools require the use of IO and so cannot be run outside of an activity.'
52+
)
53+
self.tool_activity_config[tool_name] = tool_config
54+
55+
self.run_context_type = run_context_type
56+
57+
async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
58+
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
59+
tools = await self.wrapped.get_tools(run_context)
60+
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
61+
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
62+
return {name: tool.tool_def for name, tool in tools.items()}
63+
64+
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
65+
get_tools_activity.__annotations__['deps'] = deps_type
66+
67+
self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')(
68+
get_tools_activity
69+
)
70+
71+
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
72+
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
73+
assert isinstance(params.tool_def, ToolDefinition)
74+
return await self._wrap_call_tool_result(
75+
self.wrapped.call_tool(
76+
params.name,
77+
params.tool_args,
78+
run_context,
79+
self.tool_for_tool_def(params.tool_def),
80+
)
81+
)
82+
83+
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
84+
call_tool_activity.__annotations__['deps'] = deps_type
85+
86+
self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')(
87+
call_tool_activity
88+
)
89+
90+
@abstractmethod
91+
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
92+
raise NotImplementedError
93+
94+
@property
95+
def temporal_activities(self) -> list[Callable[..., Any]]:
96+
return [self.get_tools_activity, self.call_tool_activity]
97+
98+
async def __aenter__(self) -> Self:
99+
# The wrapped MCPServer enters itself around listing and calling tools
100+
# so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
101+
return self
102+
103+
async def __aexit__(self, *args: Any) -> bool | None:
104+
return None
105+
106+
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
107+
if not workflow.in_workflow():
108+
return await super().get_tools(ctx)
109+
110+
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
111+
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
112+
activity=self.get_tools_activity,
113+
args=[
114+
_GetToolsParams(serialized_run_context=serialized_run_context),
115+
ctx.deps,
116+
],
117+
**self.activity_config,
118+
)
119+
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}
120+
121+
async def call_tool(
122+
self,
123+
name: str,
124+
tool_args: dict[str, Any],
125+
ctx: RunContext[AgentDepsT],
126+
tool: ToolsetTool[AgentDepsT],
127+
) -> CallToolResult:
128+
if not workflow.in_workflow():
129+
return await super().call_tool(name, tool_args, ctx, tool)
130+
131+
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
132+
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
133+
return self._unwrap_call_tool_result(
134+
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
135+
activity=self.call_tool_activity,
136+
args=[
137+
CallToolParams(
138+
name=name,
139+
tool_args=tool_args,
140+
serialized_run_context=serialized_run_context,
141+
tool_def=tool.tool_def,
142+
),
143+
ctx.deps,
144+
],
145+
**tool_activity_config,
146+
)
147+
)
Lines changed: 11 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,18 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
4-
from dataclasses import dataclass
5-
from typing import Any, Literal
3+
from typing import Literal
64

7-
from pydantic import ConfigDict, with_config
8-
from temporalio import activity, workflow
95
from temporalio.workflow import ActivityConfig
10-
from typing_extensions import Self
116

127
from pydantic_ai import ToolsetTool
13-
from pydantic_ai.exceptions import UserError
148
from pydantic_ai.mcp import MCPServer
15-
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
9+
from pydantic_ai.tools import AgentDepsT, ToolDefinition
1610

11+
from ._mcp import TemporalMCPToolset
1712
from ._run_context import TemporalRunContext
18-
from ._toolset import (
19-
CallToolParams,
20-
CallToolResult,
21-
TemporalWrapperToolset,
22-
)
2313

2414

25-
@dataclass
26-
@with_config(ConfigDict(arbitrary_types_allowed=True))
27-
class _GetToolsParams:
28-
serialized_run_context: Any
29-
30-
31-
class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
15+
class TemporalMCPServer(TemporalMCPToolset[AgentDepsT]):
3216
def __init__(
3317
self,
3418
server: MCPServer,
@@ -39,108 +23,15 @@ def __init__(
3923
deps_type: type[AgentDepsT],
4024
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
4125
):
42-
super().__init__(server)
43-
self.activity_config = activity_config
44-
45-
self.tool_activity_config: dict[str, ActivityConfig] = {}
46-
for tool_name, tool_config in tool_activity_config.items():
47-
if tool_config is False:
48-
raise UserError(
49-
f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
50-
'but MCP tools require the use of IO and so cannot be run outside of an activity.'
51-
)
52-
self.tool_activity_config[tool_name] = tool_config
53-
54-
self.run_context_type = run_context_type
55-
56-
async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
57-
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
58-
tools = await self.wrapped.get_tools(run_context)
59-
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
60-
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
61-
return {name: tool.tool_def for name, tool in tools.items()}
62-
63-
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
64-
get_tools_activity.__annotations__['deps'] = deps_type
65-
66-
self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')(
67-
get_tools_activity
68-
)
69-
70-
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
71-
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
72-
assert isinstance(params.tool_def, ToolDefinition)
73-
return await self._wrap_call_tool_result(
74-
self.wrapped.call_tool(
75-
params.name,
76-
params.tool_args,
77-
run_context,
78-
self.tool_for_tool_def(params.tool_def),
79-
)
80-
)
81-
82-
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
83-
call_tool_activity.__annotations__['deps'] = deps_type
84-
85-
self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')(
86-
call_tool_activity
26+
super().__init__(
27+
server,
28+
activity_name_prefix=activity_name_prefix,
29+
activity_config=activity_config,
30+
tool_activity_config=tool_activity_config,
31+
deps_type=deps_type,
32+
run_context_type=run_context_type,
8733
)
8834

8935
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
9036
assert isinstance(self.wrapped, MCPServer)
9137
return self.wrapped.tool_for_tool_def(tool_def)
92-
93-
@property
94-
def temporal_activities(self) -> list[Callable[..., Any]]:
95-
return [self.get_tools_activity, self.call_tool_activity]
96-
97-
async def __aenter__(self) -> Self:
98-
# The wrapped MCPServer enters itself around listing and calling tools
99-
# so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
100-
return self
101-
102-
async def __aexit__(self, *args: Any) -> bool | None:
103-
return None
104-
105-
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
106-
if not workflow.in_workflow():
107-
return await super().get_tools(ctx)
108-
109-
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
110-
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
111-
activity=self.get_tools_activity,
112-
args=[
113-
_GetToolsParams(serialized_run_context=serialized_run_context),
114-
ctx.deps,
115-
],
116-
**self.activity_config,
117-
)
118-
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}
119-
120-
async def call_tool(
121-
self,
122-
name: str,
123-
tool_args: dict[str, Any],
124-
ctx: RunContext[AgentDepsT],
125-
tool: ToolsetTool[AgentDepsT],
126-
) -> CallToolResult:
127-
if not workflow.in_workflow():
128-
return await super().call_tool(name, tool_args, ctx, tool)
129-
130-
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
131-
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
132-
return self._unwrap_call_tool_result(
133-
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
134-
activity=self.call_tool_activity,
135-
args=[
136-
CallToolParams(
137-
name=name,
138-
tool_args=tool_args,
139-
serialized_run_context=serialized_run_context,
140-
tool_def=tool.tool_def,
141-
),
142-
ctx.deps,
143-
],
144-
**tool_activity_config,
145-
)
146-
)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,21 @@ def temporalize_toolset(
142142
run_context_type=run_context_type,
143143
)
144144

145+
try:
146+
from pydantic_ai.toolsets.fastmcp import FastMCPToolset
147+
148+
from ._fastmcp_toolset import TemporalFastMCPToolset
149+
except ImportError:
150+
pass
151+
else:
152+
if isinstance(toolset, FastMCPToolset):
153+
return TemporalFastMCPToolset(
154+
toolset,
155+
activity_name_prefix=activity_name_prefix,
156+
activity_config=activity_config,
157+
tool_activity_config=tool_activity_config,
158+
deps_type=deps_type,
159+
run_context_type=run_context_type,
160+
)
161+
145162
return toolset

0 commit comments

Comments
 (0)