Skip to content

Commit 9da2d58

Browse files
committed
add support for supplying a progress callback to call_tool requests
1 parent 58c5e72 commit 9da2d58

File tree

2 files changed

+153
-10
lines changed

2 files changed

+153
-10
lines changed

src/mcp/client/session.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ async def __call__(
3535
) -> None: ...
3636

3737

38+
class ProgressFnT(Protocol):
39+
async def __call__(
40+
self,
41+
params: types.ProgressNotificationParams,
42+
) -> None: ...
43+
44+
3845
class MessageHandlerFnT(Protocol):
3946
async def __call__(
4047
self,
@@ -91,6 +98,9 @@ class ClientSession(
9198
types.ServerNotification,
9299
]
93100
):
101+
_progress_id: int
102+
_in_progress: dict[types.ProgressToken, ProgressFnT]
103+
94104
def __init__(
95105
self,
96106
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -114,6 +124,8 @@ def __init__(
114124
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
115125
self._logging_callback = logging_callback or _default_logging_callback
116126
self._message_handler = message_handler or _default_message_handler
127+
self._progress_id = 0
128+
self._in_progress = {}
117129

118130
async def initialize(self) -> types.InitializeResult:
119131
sampling = types.SamplingCapability()
@@ -259,19 +271,37 @@ async def call_tool(
259271
name: str,
260272
arguments: dict[str, Any] | None = None,
261273
read_timeout_seconds: timedelta | None = None,
274+
progress_callback: ProgressFnT | None = None,
262275
) -> types.CallToolResult:
263276
"""Send a tools/call request."""
264277

265-
return await self.send_request(
266-
types.ClientRequest(
267-
types.CallToolRequest(
268-
method="tools/call",
269-
params=types.CallToolRequestParams(name=name, arguments=arguments),
270-
)
271-
),
272-
types.CallToolResult,
273-
request_read_timeout_seconds=read_timeout_seconds,
274-
)
278+
if progress_callback is None:
279+
progress_id = None
280+
call_params = types.CallToolRequestParams(name=name, arguments=arguments)
281+
else:
282+
progress_id = self._progress_id
283+
self._progress_id = progress_id + 1
284+
285+
call_meta = types.RequestParams.Meta(progressToken=progress_id)
286+
call_params = types.CallToolRequestParams(
287+
name=name, arguments=arguments, _meta=call_meta
288+
)
289+
self._in_progress[progress_id] = progress_callback
290+
291+
try:
292+
return await self.send_request(
293+
types.ClientRequest(
294+
types.CallToolRequest(
295+
method="tools/call",
296+
params=call_params,
297+
)
298+
),
299+
types.CallToolResult,
300+
request_read_timeout_seconds=read_timeout_seconds,
301+
)
302+
finally:
303+
if progress_id is not None:
304+
self._in_progress.pop(progress_id, None)
275305

276306
async def list_prompts(self) -> types.ListPromptsResult:
277307
"""Send a prompts/list request."""
@@ -384,5 +414,9 @@ async def _received_notification(
384414
match notification.root:
385415
case types.LoggingMessageNotification(params=params):
386416
await self._logging_callback(params)
417+
case types.ProgressNotification(params=params):
418+
if params.progressToken in self._in_progress:
419+
progress_callback = self._in_progress[params.progressToken]
420+
await progress_callback(params)
387421
case _:
388422
pass

tests/client/test_session.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,112 @@ async def mock_server():
250250

251251
# Assert that the default client info was sent
252252
assert received_client_info == DEFAULT_CLIENT_INFO
253+
254+
255+
@pytest.mark.anyio
256+
async def test_client_session_progress():
257+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
258+
SessionMessage
259+
](1)
260+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
261+
SessionMessage
262+
](1)
263+
264+
async def mock_server():
265+
session_message = await client_to_server_receive.receive()
266+
jsonrpc_request = session_message.message
267+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
268+
request = ClientRequest.model_validate(
269+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
270+
)
271+
assert isinstance(request.root, types.CallToolRequest)
272+
assert request.root.params.meta
273+
assert request.root.params.meta.progressToken is not None
274+
275+
progress_token = request.root.params.meta.progressToken
276+
277+
notifications = [
278+
types.ServerNotification(
279+
root=types.ProgressNotification(
280+
params=types.ProgressNotificationParams(
281+
progressToken=progress_token, progress=1
282+
),
283+
method="notifications/progress",
284+
)
285+
),
286+
types.ServerNotification(
287+
root=types.ProgressNotification(
288+
params=types.ProgressNotificationParams(
289+
progressToken=progress_token, progress=2
290+
),
291+
method="notifications/progress",
292+
)
293+
),
294+
]
295+
result = ServerResult(types.CallToolResult(content=[]))
296+
297+
async with server_to_client_send:
298+
for notification in notifications:
299+
await server_to_client_send.send(
300+
SessionMessage(
301+
JSONRPCMessage(
302+
types.JSONRPCNotification(
303+
jsonrpc="2.0",
304+
**notification.model_dump(
305+
by_alias=True, mode="json", exclude_none=True
306+
),
307+
)
308+
)
309+
)
310+
)
311+
await server_to_client_send.send(
312+
SessionMessage(
313+
JSONRPCMessage(
314+
JSONRPCResponse(
315+
jsonrpc="2.0",
316+
id=jsonrpc_request.root.id,
317+
result=result.model_dump(
318+
by_alias=True, mode="json", exclude_none=True
319+
),
320+
)
321+
)
322+
)
323+
)
324+
325+
# Create a message handler to catch exceptions
326+
async def message_handler(
327+
message: RequestResponder[types.ServerRequest, types.ClientResult]
328+
| types.ServerNotification
329+
| Exception,
330+
) -> None:
331+
if isinstance(message, Exception):
332+
raise message
333+
334+
progress_count = 0
335+
336+
async with (
337+
ClientSession(
338+
server_to_client_receive,
339+
client_to_server_send,
340+
message_handler=message_handler,
341+
) as session,
342+
anyio.create_task_group() as tg,
343+
client_to_server_send,
344+
client_to_server_receive,
345+
server_to_client_send,
346+
server_to_client_receive,
347+
):
348+
tg.start_soon(mock_server)
349+
350+
async def progress_callback(params: types.ProgressNotificationParams):
351+
nonlocal progress_count
352+
progress_count = progress_count + 1
353+
354+
result = await session.call_tool(
355+
"tool_with_progress", progress_callback=progress_callback
356+
)
357+
358+
# Assert the result
359+
assert isinstance(result, types.CallToolResult)
360+
assert len(result.content) == 0
361+
assert progress_count == 2

0 commit comments

Comments
 (0)