|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from collections.abc import Callable |
4 | | -from dataclasses import dataclass |
5 | | -from typing import Any, Literal |
| 3 | +from typing import Literal |
6 | 4 |
|
7 | | -from pydantic import ConfigDict, with_config |
8 | | -from temporalio import activity, workflow |
9 | 5 | from temporalio.workflow import ActivityConfig |
10 | | -from typing_extensions import Self |
11 | 6 |
|
12 | 7 | from pydantic_ai import ToolsetTool |
13 | | -from pydantic_ai.exceptions import UserError |
14 | | -from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition |
| 8 | +from pydantic_ai.tools import AgentDepsT, ToolDefinition |
15 | 9 | from pydantic_ai.toolsets.fastmcp import FastMCPToolset |
16 | 10 |
|
| 11 | +from ._mcp import TemporalMCPToolset |
17 | 12 | from ._run_context import TemporalRunContext |
18 | | -from ._toolset import ( |
19 | | - CallToolParams, |
20 | | - CallToolResult, |
21 | | - TemporalWrapperToolset, |
22 | | -) |
23 | 13 |
|
24 | 14 |
|
25 | | -@dataclass |
26 | | -@with_config(ConfigDict(arbitrary_types_allowed=True)) |
27 | | -class _GetToolsParams: |
28 | | - serialized_run_context: Any |
29 | | - |
30 | | - |
31 | | -class TemporalFastMCPToolset(TemporalWrapperToolset[AgentDepsT]): |
| 15 | +class TemporalFastMCPToolset(TemporalMCPToolset[AgentDepsT]): |
32 | 16 | def __init__( |
33 | 17 | self, |
34 | | - server: FastMCPToolset[AgentDepsT], |
| 18 | + toolset: FastMCPToolset[AgentDepsT], |
35 | 19 | *, |
36 | 20 | activity_name_prefix: str, |
37 | 21 | activity_config: ActivityConfig, |
38 | 22 | tool_activity_config: dict[str, ActivityConfig | Literal[False]], |
39 | 23 | deps_type: type[AgentDepsT], |
40 | 24 | run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT], |
41 | 25 | ): |
42 | | - super().__init__(server) |
43 | | - self.activity_config = activity_config |
44 | | - |
45 | | - self.tool_activity_config: dict[str, ActivityConfig] = {} |
46 | | - for tool_name, tool_config in tool_activity_config.items(): |
47 | | - if tool_config is False: |
48 | | - raise UserError( |
49 | | - f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), ' |
50 | | - 'but MCP tools require the use of IO and so cannot be run outside of an activity.' |
51 | | - ) |
52 | | - self.tool_activity_config[tool_name] = tool_config |
53 | | - |
54 | | - self.run_context_type = run_context_type |
55 | | - |
56 | | - async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]: |
57 | | - run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps) |
58 | | - tools = await self.wrapped.get_tools(run_context) |
59 | | - # ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time), |
60 | | - # so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity. |
61 | | - return {name: tool.tool_def for name, tool in tools.items()} |
62 | | - |
63 | | - # Set type hint explicitly so that Temporal can take care of serialization and deserialization |
64 | | - get_tools_activity.__annotations__['deps'] = deps_type |
65 | | - |
66 | | - self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')( |
67 | | - get_tools_activity |
68 | | - ) |
69 | | - |
70 | | - async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult: |
71 | | - run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps) |
72 | | - assert isinstance(params.tool_def, ToolDefinition) |
73 | | - return await self._wrap_call_tool_result( |
74 | | - self.wrapped.call_tool( |
75 | | - params.name, |
76 | | - params.tool_args, |
77 | | - run_context, |
78 | | - self.tool_for_tool_def(params.tool_def), |
79 | | - ) |
80 | | - ) |
81 | | - |
82 | | - # Set type hint explicitly so that Temporal can take care of serialization and deserialization |
83 | | - call_tool_activity.__annotations__['deps'] = deps_type |
84 | | - |
85 | | - self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')( |
86 | | - call_tool_activity |
| 26 | + super().__init__( |
| 27 | + toolset, |
| 28 | + activity_name_prefix=activity_name_prefix, |
| 29 | + activity_config=activity_config, |
| 30 | + tool_activity_config=tool_activity_config, |
| 31 | + deps_type=deps_type, |
| 32 | + run_context_type=run_context_type, |
87 | 33 | ) |
88 | 34 |
|
89 | 35 | def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]: |
90 | 36 | assert isinstance(self.wrapped, FastMCPToolset) |
91 | 37 | return self.wrapped.tool_for_tool_def(tool_def) |
92 | | - |
93 | | - @property |
94 | | - def temporal_activities(self) -> list[Callable[..., Any]]: |
95 | | - return [self.get_tools_activity, self.call_tool_activity] |
96 | | - |
97 | | - async def __aenter__(self) -> Self: |
98 | | - # The wrapped MCPServer enters itself around listing and calling tools |
99 | | - # so we don't need to enter it here (nor could we because we're not inside a Temporal activity). |
100 | | - return self |
101 | | - |
102 | | - async def __aexit__(self, *args: Any) -> bool | None: |
103 | | - return None |
104 | | - |
105 | | - async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: |
106 | | - if not workflow.in_workflow(): |
107 | | - return await super().get_tools(ctx) |
108 | | - |
109 | | - serialized_run_context = self.run_context_type.serialize_run_context(ctx) |
110 | | - tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] |
111 | | - activity=self.get_tools_activity, |
112 | | - args=[ |
113 | | - _GetToolsParams(serialized_run_context=serialized_run_context), |
114 | | - ctx.deps, |
115 | | - ], |
116 | | - **self.activity_config, |
117 | | - ) |
118 | | - return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()} |
119 | | - |
120 | | - async def call_tool( |
121 | | - self, |
122 | | - name: str, |
123 | | - tool_args: dict[str, Any], |
124 | | - ctx: RunContext[AgentDepsT], |
125 | | - tool: ToolsetTool[AgentDepsT], |
126 | | - ) -> CallToolResult: |
127 | | - if not workflow.in_workflow(): |
128 | | - return await super().call_tool(name, tool_args, ctx, tool) |
129 | | - |
130 | | - tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {}) |
131 | | - serialized_run_context = self.run_context_type.serialize_run_context(ctx) |
132 | | - return self._unwrap_call_tool_result( |
133 | | - await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType] |
134 | | - activity=self.call_tool_activity, |
135 | | - args=[ |
136 | | - CallToolParams( |
137 | | - name=name, |
138 | | - tool_args=tool_args, |
139 | | - serialized_run_context=serialized_run_context, |
140 | | - tool_def=tool.tool_def, |
141 | | - ), |
142 | | - ctx.deps, |
143 | | - ], |
144 | | - **tool_activity_config, |
145 | | - ) |
146 | | - ) |
0 commit comments