Skip to content
Original file line number Diff line number Diff line change
@@ -1,57 +1,22 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from typing import Any, Literal

from pydantic import ConfigDict, Discriminator, with_config
from temporalio import activity, workflow
from temporalio.workflow import ActivityConfig
from typing_extensions import assert_never

from pydantic_ai import FunctionToolset, ToolsetTool
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
from pydantic_ai.exceptions import UserError
from pydantic_ai.tools import AgentDepsT, RunContext
from pydantic_ai.toolsets.function import FunctionToolsetTool

from ._run_context import TemporalRunContext
from ._toolset import TemporalWrapperToolset


@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _CallToolParams:
name: str
tool_args: dict[str, Any]
serialized_run_context: Any


@dataclass
class _ApprovalRequired:
kind: Literal['approval_required'] = 'approval_required'


@dataclass
class _CallDeferred:
kind: Literal['call_deferred'] = 'call_deferred'


@dataclass
class _ModelRetry:
message: str
kind: Literal['model_retry'] = 'model_retry'


@dataclass
class _ToolReturn:
result: Any
kind: Literal['tool_return'] = 'tool_return'


_CallToolResult = Annotated[
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
Discriminator('kind'),
]
from ._toolset import (
CallToolParams,
CallToolResult,
TemporalWrapperToolset,
)


class TemporalFunctionToolset(TemporalWrapperToolset[AgentDepsT]):
Expand All @@ -70,7 +35,7 @@ def __init__(
self.tool_activity_config = tool_activity_config
self.run_context_type = run_context_type

async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
name = params.name
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
try:
Expand All @@ -84,15 +49,7 @@ async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _Call
# The tool args will already have been validated into their proper types in the `ToolManager`,
# but `execute_activity` would have turned them into simple Python types again, so we need to re-validate them.
args_dict = tool.args_validator.validate_python(params.tool_args)
try:
result = await self.wrapped.call_tool(name, args_dict, ctx, tool)
return _ToolReturn(result=result)
except ApprovalRequired:
return _ApprovalRequired()
except CallDeferred:
return _CallDeferred()
except ModelRetry as e:
return _ModelRetry(message=e.message)
return await self._wrap_call_tool_result(self.wrapped.call_tool(name, args_dict, ctx, tool))

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
call_tool_activity.__annotations__['deps'] = deps_type
Expand Down Expand Up @@ -123,25 +80,18 @@ async def call_tool(

tool_activity_config = self.activity_config | tool_activity_config
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
result = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.call_tool_activity,
args=[
_CallToolParams(
name=name,
tool_args=tool_args,
serialized_run_context=serialized_run_context,
),
ctx.deps,
],
**tool_activity_config,
return self._unwrap_call_tool_result(
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.call_tool_activity,
args=[
CallToolParams(
name=name,
tool_args=tool_args,
serialized_run_context=serialized_run_context,
tool_def=None,
),
ctx.deps,
],
**tool_activity_config,
)
)
if isinstance(result, _ApprovalRequired):
raise ApprovalRequired()
elif isinstance(result, _CallDeferred):
raise CallDeferred()
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
elif isinstance(result, _ToolReturn):
return result.result
else:
assert_never(result)
60 changes: 30 additions & 30 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@

from pydantic_ai import ToolsetTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.mcp import MCPServer, ToolResult
from pydantic_ai.mcp import MCPServer
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition

from ._run_context import TemporalRunContext
from ._toolset import TemporalWrapperToolset
from ._toolset import (
CallToolParams,
CallToolResult,
TemporalWrapperToolset,
)


@dataclass
Expand All @@ -24,15 +28,6 @@ class _GetToolsParams:
serialized_run_context: Any


@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _CallToolParams:
name: str
tool_args: dict[str, Any]
serialized_run_context: Any
tool_def: ToolDefinition


class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
def __init__(
self,
Expand Down Expand Up @@ -72,13 +67,16 @@ async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[
get_tools_activity
)

async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> ToolResult:
async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
return await self.wrapped.call_tool(
params.name,
params.tool_args,
run_context,
self.tool_for_tool_def(params.tool_def),
assert isinstance(params.tool_def, ToolDefinition)
return await self._wrap_call_tool_result(
self.wrapped.call_tool(
params.name,
params.tool_args,
run_context,
self.tool_for_tool_def(params.tool_def),
)
)

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
Expand Down Expand Up @@ -125,22 +123,24 @@ async def call_tool(
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> ToolResult:
) -> CallToolResult:
if not workflow.in_workflow():
return await super().call_tool(name, tool_args, ctx, tool)

tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.call_tool_activity,
args=[
_CallToolParams(
name=name,
tool_args=tool_args,
serialized_run_context=serialized_run_context,
tool_def=tool.tool_def,
),
ctx.deps,
],
**tool_activity_config,
return self._unwrap_call_tool_result(
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.call_tool_activity,
args=[
CallToolParams(
name=name,
tool_args=tool_args,
serialized_run_context=serialized_run_context,
tool_def=tool.tool_def,
),
ctx.deps,
],
**tool_activity_config,
)
)
70 changes: 67 additions & 3 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,58 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, Literal
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Annotated, Any, Literal

from pydantic import ConfigDict, Discriminator, with_config
from temporalio.workflow import ActivityConfig
from typing_extensions import assert_never

from pydantic_ai import AbstractToolset, FunctionToolset, WrapperToolset
from pydantic_ai.tools import AgentDepsT
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry
from pydantic_ai.tools import AgentDepsT, ToolDefinition

from ._run_context import TemporalRunContext


@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class CallToolParams:
name: str
tool_args: dict[str, Any]
serialized_run_context: Any
tool_def: ToolDefinition | None


@dataclass
class _ApprovalRequired:
kind: Literal['approval_required'] = 'approval_required'


@dataclass
class _CallDeferred:
kind: Literal['call_deferred'] = 'call_deferred'


@dataclass
class _ModelRetry:
message: str
kind: Literal['model_retry'] = 'model_retry'


@dataclass
class _ToolReturn:
result: Any
kind: Literal['tool_return'] = 'tool_return'


CallToolResult = Annotated[
_ApprovalRequired | _CallDeferred | _ModelRetry | _ToolReturn,
Discriminator('kind'),
]


class TemporalWrapperToolset(WrapperToolset[AgentDepsT], ABC):
@property
def id(self) -> str:
Expand All @@ -30,6 +71,29 @@ def visit_and_replace(
# Temporalized toolsets cannot be swapped out after the fact.
return self

async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
try:
result = await coro
return _ToolReturn(result=result)
except ApprovalRequired:
return _ApprovalRequired()
except CallDeferred:
return _CallDeferred()
except ModelRetry as e:
return _ModelRetry(message=e.message)

def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
if isinstance(result, _ToolReturn):
return result.result
elif isinstance(result, _ApprovalRequired):
raise ApprovalRequired()
elif isinstance(result, _CallDeferred):
raise CallDeferred()
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
else:
assert_never(result)


def temporalize_toolset(
toolset: AbstractToolset[AgentDepsT],
Expand Down