11from __future__ import annotations
22
33from collections .abc import Callable
4+ from dataclasses import dataclass
45from typing import Any , Literal
56
7+ from pydantic import ConfigDict , with_config
68from temporalio import activity , workflow
79from temporalio .workflow import ActivityConfig
810from typing_extensions import Self
1416
1517from ._run_context import TemporalRunContext
1618from ._toolset import (
17- CallToolParamsData ,
18- CallToolResultData ,
19- GetToolsParamsData ,
19+ CallToolParams ,
20+ CallToolResult ,
2021 TemporalWrapperToolset ,
21- ToolReturnData ,
22- remap_dataclass_to_exception ,
23- remap_exception_to_dataclass ,
2422)
2523
2624
25+ @dataclass
26+ @with_config (ConfigDict (arbitrary_types_allowed = True ))
27+ class _GetToolsParams :
28+ serialized_run_context : Any
29+
30+
2731class TemporalMCPServer (TemporalWrapperToolset [AgentDepsT ]):
2832 def __init__ (
2933 self ,
@@ -49,7 +53,7 @@ def __init__(
4953
5054 self .run_context_type = run_context_type
5155
52- async def get_tools_activity (params : GetToolsParamsData , deps : AgentDepsT ) -> dict [str , ToolDefinition ]:
56+ async def get_tools_activity (params : _GetToolsParams , deps : AgentDepsT ) -> dict [str , ToolDefinition ]:
5357 run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
5458 tools = await self .wrapped .get_tools (run_context )
5559 # ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
@@ -63,19 +67,17 @@ async def get_tools_activity(params: GetToolsParamsData, deps: AgentDepsT) -> di
6367 get_tools_activity
6468 )
6569
66- async def call_tool_activity (params : CallToolParamsData , deps : AgentDepsT ) -> CallToolResultData :
70+ async def call_tool_activity (params : CallToolParams , deps : AgentDepsT ) -> CallToolResult :
6771 run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
68- try :
69- assert isinstance ( params . tool_def , ToolDefinition )
70- result = await self .wrapped .call_tool (
72+ assert isinstance ( params . tool_def , ToolDefinition )
73+ return await self . _wrap_call_tool_result (
74+ self .wrapped .call_tool (
7175 params .name ,
7276 params .tool_args ,
7377 run_context ,
7478 self .tool_for_tool_def (params .tool_def ),
7579 )
76- return ToolReturnData (result = result )
77- except Exception as e :
78- return remap_exception_to_dataclass (e )
80+ )
7981
8082 # Set type hint explicitly so that Temporal can take care of serialization and deserialization
8183 call_tool_activity .__annotations__ ['deps' ] = deps_type
@@ -108,7 +110,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
108110 tool_defs = await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
109111 activity = self .get_tools_activity ,
110112 args = [
111- GetToolsParamsData (serialized_run_context = serialized_run_context ),
113+ _GetToolsParams (serialized_run_context = serialized_run_context ),
112114 ctx .deps ,
113115 ],
114116 ** self .activity_config ,
@@ -121,23 +123,24 @@ async def call_tool(
121123 tool_args : dict [str , Any ],
122124 ctx : RunContext [AgentDepsT ],
123125 tool : ToolsetTool [AgentDepsT ],
124- ) -> CallToolResultData :
126+ ) -> CallToolResult :
125127 if not workflow .in_workflow ():
126128 return await super ().call_tool (name , tool_args , ctx , tool )
127129
128130 tool_activity_config = self .activity_config | self .tool_activity_config .get (name , {})
129131 serialized_run_context = self .run_context_type .serialize_run_context (ctx )
130- result = await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
131- activity = self .call_tool_activity ,
132- args = [
133- CallToolParamsData (
134- name = name ,
135- tool_args = tool_args ,
136- serialized_run_context = serialized_run_context ,
137- tool_def = tool .tool_def ,
138- ),
139- ctx .deps ,
140- ],
141- ** tool_activity_config ,
132+ return self ._unwrap_call_tool_result (
133+ await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
134+ activity = self .call_tool_activity ,
135+ args = [
136+ CallToolParams (
137+ name = name ,
138+ tool_args = tool_args ,
139+ serialized_run_context = serialized_run_context ,
140+ tool_def = tool .tool_def ,
141+ ),
142+ ctx .deps ,
143+ ],
144+ ** tool_activity_config ,
145+ )
142146 )
143- return remap_dataclass_to_exception (result )
0 commit comments