4
4
import functools
5
5
import json
6
6
from abc import ABC , abstractmethod
7
- from collections .abc import AsyncIterator , Sequence
7
+ from collections .abc import AsyncIterator , Awaitable , Sequence
8
8
from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
9
9
from dataclasses import dataclass
10
10
from pathlib import Path
20
20
from mcp .types import (
21
21
AudioContent ,
22
22
BlobResourceContents ,
23
+ CallToolRequest ,
24
+ CallToolRequestParams ,
25
+ CallToolResult ,
26
+ ClientRequest ,
23
27
Content ,
24
28
EmbeddedResource ,
25
29
ImageContent ,
26
30
LoggingLevel ,
31
+ RequestParams ,
27
32
TextContent ,
28
33
TextResourceContents ,
29
34
)
30
35
from typing_extensions import Self , assert_never , deprecated
31
36
32
37
from pydantic_ai .exceptions import ModelRetry
33
38
from pydantic_ai .messages import BinaryContent
34
- from pydantic_ai .tools import ToolDefinition
39
+ from pydantic_ai .tools import RunContext , ToolDefinition
35
40
36
41
try :
37
42
from mcp .client .session import ClientSession
@@ -61,6 +66,9 @@ class MCPServer(ABC):
61
66
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
62
67
"""
63
68
69
+ process_tool_call : ProcessToolCallback | None = None
70
+ """Hook to customize tool calling and optionally pass extra metadata."""
71
+
64
72
_client : ClientSession
65
73
_read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ]
66
74
_write_stream : MemoryObjectSendStream [SessionMessage ]
@@ -114,13 +122,17 @@ async def list_tools(self) -> list[ToolDefinition]:
114
122
]
115
123
116
124
async def call_tool (
117
- self , tool_name : str , arguments : dict [str , Any ]
118
- ) -> str | BinaryContent | dict [str , Any ] | list [Any ] | Sequence [str | BinaryContent | dict [str , Any ] | list [Any ]]:
125
+ self ,
126
+ tool_name : str ,
127
+ arguments : dict [str , Any ],
128
+ metadata : dict [str , Any ] | None = None ,
129
+ ) -> ToolResult :
119
130
"""Call a tool on the server.
120
131
121
132
Args:
122
133
tool_name: The name of the tool to call.
123
134
arguments: The arguments to pass to the tool.
135
+ metadata: Request-level metadata (optional)
124
136
125
137
Returns:
126
138
The result of the tool call.
@@ -129,7 +141,20 @@ async def call_tool(
129
141
ModelRetry: If the tool call fails.
130
142
"""
131
143
try :
132
- result = await self ._client .call_tool (self .get_unprefixed_tool_name (tool_name ), arguments )
144
+ # meta param is not provided by session yet, so build and can send_request directly.
145
+ result = await self ._client .send_request (
146
+ ClientRequest (
147
+ CallToolRequest (
148
+ method = 'tools/call' ,
149
+ params = CallToolRequestParams (
150
+ name = self .get_unprefixed_tool_name (tool_name ),
151
+ arguments = arguments ,
152
+ _meta = RequestParams .Meta (** metadata ) if metadata else None ,
153
+ ),
154
+ )
155
+ ),
156
+ CallToolResult ,
157
+ )
133
158
except McpError as e :
134
159
raise ModelRetry (e .error .message )
135
160
@@ -269,6 +294,9 @@ async def main():
269
294
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
270
295
"""
271
296
297
+ process_tool_call : ProcessToolCallback | None = None
298
+ """Hook to customize tool calling and optionally pass extra metadata."""
299
+
272
300
timeout : float = 5
273
301
""" The timeout in seconds to wait for the client to initialize."""
274
302
@@ -363,6 +391,9 @@ class _MCPServerHTTP(MCPServer):
363
391
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
364
392
"""
365
393
394
+ process_tool_call : ProcessToolCallback | None = None
395
+ """Hook to customize tool calling and optionally pass extra metadata."""
396
+
366
397
@property
367
398
@abstractmethod
368
399
def _transport_client (
@@ -521,3 +552,29 @@ async def main():
521
552
@property
522
553
def _transport_client (self ):
523
554
return streamablehttp_client # pragma: no cover
555
+
556
+
557
+ ToolResult = (
558
+ str | BinaryContent | dict [str , Any ] | list [Any ] | Sequence [str | BinaryContent | dict [str , Any ] | list [Any ]]
559
+ )
560
+ """The result type of a tool call."""
561
+
562
+ CallToolFunc = Callable [[str , dict [str , Any ], dict [str , Any ] | None ], Awaitable [ToolResult ]]
563
+ """A function type that represents a tool call."""
564
+
565
+ ProcessToolCallback = Callable [
566
+ [
567
+ RunContext [Any ],
568
+ CallToolFunc ,
569
+ str ,
570
+ dict [str , Any ],
571
+ ],
572
+ Awaitable [ToolResult ],
573
+ ]
574
+ """A process tool callback.
575
+
576
+ It accepts a run context, the original tool call function, a tool name, and arguments.
577
+
578
+ Allows wrapping an MCP server tool call to customize it, including adding extra request
579
+ metadata.
580
+ """
0 commit comments