Skip to content
56 changes: 46 additions & 10 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal
from typing import Annotated, Any, Literal

from pydantic import ConfigDict, with_config
from pydantic import ConfigDict, Discriminator, with_config
from temporalio import activity, workflow
from temporalio.workflow import ActivityConfig
from typing_extensions import Self

from pydantic_ai import ToolsetTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
from pydantic_ai.mcp import MCPServer, ToolResult
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition

Expand All @@ -33,6 +33,34 @@ class _CallToolParams:
tool_def: ToolDefinition


@dataclass
class _ApprovalRequired:
kind: Literal['approval_required'] = 'approval_required'


@dataclass
class _CallDeferred:
kind: Literal['call_deferred'] = 'call_deferred'


@dataclass
class _ModelRetry:
message: str
kind: Literal['model_retry'] = 'model_retry'


@dataclass
class _ToolReturn:
result: Any
kind: Literal['tool_return'] = 'tool_return'


_CallToolResult = Annotated[
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
Discriminator('kind'),
]


class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
def __init__(
self,
Expand Down Expand Up @@ -72,14 +100,22 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[
get_tools_activity
)

async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
return await self.wrapped.call_tool(
params.name,
params.tool_args,
run_context,
self.tool_for_tool_def(params.tool_def),
)
try:
result = await self.wrapped.call_tool(
params.name,
params.tool_args,
run_context,
self.tool_for_tool_def(params.tool_def),
)
return _ToolReturn(result=result)
except ApprovalRequired:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please move this logic and the types to TemporalWrapperToolset, so that both TemporalFunctionToolset and TemporalMCPServer can use them?

return _ApprovalRequired()
except CallDeferred:
return _CallDeferred()
except ModelRetry as e:
return _ModelRetry(message=e.message)

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
call_tool_activity.__annotations__['deps'] = deps_type
Expand Down
Loading