Skip to content

Commit f8ca895

Browse files
committed
Support progress in async tools
1 parent 37fb963 commit f8ca895

File tree

3 files changed

+73
-4
lines changed

3 files changed

+73
-4
lines changed

examples/snippets/clients/async_tools_client.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,23 @@ async def demonstrate_batch_processing(session: ClientSession):
9090
print("\n=== Batch Processing Demo ===")
9191

9292
items = ["apple", "banana", "cherry", "date", "elderberry"]
93-
result = await session.call_tool("batch_operation_tool", arguments={"items": items})
93+
94+
# Define progress callback
95+
async def progress_callback(progress: float, total: float | None, message: str | None) -> None:
96+
progress_pct = int(progress * 100) if progress else 0
97+
total_str = f"/{int(total * 100)}%" if total else ""
98+
message_str = f" - {message}" if message else ""
99+
print(f"Progress: {progress_pct}{total_str}{message_str}")
100+
101+
result = await session.call_tool(
102+
"batch_operation_tool", arguments={"items": items}, progress_callback=progress_callback
103+
)
94104

95105
if result.operation:
96106
token = result.operation.token
97107
print(f"Batch operation started with token: {token}")
98108

99-
# Poll for status with progress tracking
109+
# Poll for status
100110
while True:
101111
status = await session.get_operation_status(token)
102112
print(f"Status: {status.status}")

src/mcp/shared/session.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
from mcp.types import (
1717
CONNECTION_CLOSED,
1818
INVALID_PARAMS,
19+
CallToolResult,
1920
CancelledNotification,
2021
ClientNotification,
2122
ClientRequest,
2223
ClientResult,
2324
ErrorData,
25+
GetOperationPayloadRequest,
26+
GetOperationPayloadResult,
2427
JSONRPCError,
2528
JSONRPCMessage,
2629
JSONRPCNotification,
@@ -177,6 +180,7 @@ class BaseSession(
177180
_request_id: int
178181
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
179182
_progress_callbacks: dict[RequestId, ProgressFnT]
183+
_operation_requests: dict[str, RequestId]
180184

181185
def __init__(
182186
self,
@@ -196,6 +200,7 @@ def __init__(
196200
self._session_read_timeout_seconds = read_timeout_seconds
197201
self._in_flight = {}
198202
self._progress_callbacks = {}
203+
self._operation_requests = {}
199204
self._exit_stack = AsyncExitStack()
200205

201206
async def __aenter__(self) -> Self:
@@ -251,6 +256,7 @@ async def send_request(
251256
# Store the callback for this request
252257
self._progress_callbacks[request_id] = progress_callback
253258

259+
pop_progress: RequestId | None = request_id
254260
try:
255261
jsonrpc_request = JSONRPCRequest(
256262
jsonrpc="2.0",
@@ -285,11 +291,28 @@ async def send_request(
285291
if isinstance(response_or_error, JSONRPCError):
286292
raise McpError(response_or_error.error)
287293
else:
288-
return result_type.model_validate(response_or_error.result)
294+
result = result_type.model_validate(response_or_error.result)
295+
if isinstance(result, CallToolResult) and result.operation is not None:
296+
# Store mapping of operation token to request ID for async operations
297+
self._operation_requests[result.operation.token] = request_id
298+
299+
# Don't pop the progress function if we were given one
300+
pop_progress = None
301+
elif isinstance(request, GetOperationPayloadRequest) and isinstance(result, GetOperationPayloadResult):
302+
# Checked request and result to ensure no error
303+
operation_token = request.params.token
304+
305+
# Pop the progress function for the original request
306+
pop_progress = self._operation_requests[operation_token]
307+
308+
# Pop the token mapping since we know we won't need it anymore
309+
self._operation_requests.pop(operation_token, None)
310+
return result
289311

290312
finally:
291313
self._response_streams.pop(request_id, None)
292-
self._progress_callbacks.pop(request_id, None)
314+
if pop_progress:
315+
self._progress_callbacks.pop(pop_progress, None)
293316
await response_stream.aclose()
294317
await response_stream_reader.aclose()
295318

tests/server/fastmcp/test_integration.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,42 @@ async def test_async_tools(server_transport: str, server_url: str) -> None:
760760
with pytest.raises(Exception): # Should raise error when trying to access expired operation
761761
await session.get_operation_result(quick_token)
762762

763+
# Test batch operation with progress notifications
764+
progress_received = False
765+
766+
async def progress_callback(progress: float, total: float | None, message: str | None) -> None:
767+
nonlocal progress_received
768+
progress_received = True
769+
assert 0.0 <= progress <= 1.0 # Progress should be between 0 and 1
770+
771+
batch_result = await session.call_tool(
772+
"batch_operation_tool",
773+
{"items": ["apple", "banana", "cherry"]},
774+
progress_callback=progress_callback,
775+
)
776+
assert batch_result.operation is not None
777+
batch_token = batch_result.operation.token
778+
779+
while True:
780+
status = await session.get_operation_status(batch_token)
781+
782+
if status.status == "completed":
783+
final_result = await session.get_operation_result(batch_token)
784+
assert not final_result.result.isError
785+
# Should have structured content with processed items
786+
if final_result.result.structuredContent:
787+
# Structured content is wrapped in {"result": [...]} for list return types
788+
assert isinstance(final_result.result.structuredContent, dict)
789+
assert "result" in final_result.result.structuredContent
790+
assert isinstance(final_result.result.structuredContent["result"], list)
791+
assert len(final_result.result.structuredContent["result"]) == 3
792+
break
793+
elif status.status == "failed":
794+
pytest.fail(f"Batch operation failed: {status.error}")
795+
796+
# Assert that we received at least one progress notification
797+
assert progress_received, "Should have received progress notifications during batch operation"
798+
763799

764800
# Test async tools example with legacy protocol
765801
@pytest.mark.anyio

0 commit comments

Comments
 (0)