Skip to content

Commit 80ad862

Browse files
committed
more type and test case updates
1 parent 71a9418 commit 80ad862

File tree

5 files changed

+233
-9
lines changed

5 files changed

+233
-9
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,13 +621,14 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
621621
return self._request_context
622622

623623
async def report_progress(
624-
self, progress: float, total: float | None = None
624+
self, progress: float, total: float | None = None, message: str | None = None
625625
) -> None:
626626
"""Report progress for the current operation.
627627
628628
Args:
629629
progress: Current progress value e.g. 24
630630
total: Optional total value e.g. 100
631+
message: Optional message e.g. Starting render...
631632
"""
632633

633634
progress_token = (
@@ -640,7 +641,10 @@ async def report_progress(
640641
return
641642

642643
await self.request_context.session.send_progress_notification(
643-
progress_token=progress_token, progress=progress, total=total
644+
progress_token=progress_token,
645+
progress=progress,
646+
total=total,
647+
message=message,
644648
)
645649

646650
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:

src/mcp/server/lowlevel/server.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
3737
3. Define notification handlers if needed:
3838
@server.progress_notification()
3939
async def handle_progress(
40-
progress_token: str | int, progress: float, total: float | None
40+
progress_token: str | int, progress: float, total: float | None,
41+
message: str | None
4142
) -> None:
4243
# Implementation
4344
@@ -426,13 +427,18 @@ async def handler(req: types.CallToolRequest):
426427

427428
def progress_notification(self):
428429
def decorator(
429-
func: Callable[[str | int, float, float | None], Awaitable[None]],
430+
func: Callable[
431+
[str | int, float, float | None, str | None], Awaitable[None]
432+
],
430433
):
431434
logger.debug("Registering handler for ProgressNotification")
432435

433436
async def handler(req: types.ProgressNotification):
434437
await func(
435-
req.params.progressToken, req.params.progress, req.params.total
438+
req.params.progressToken,
439+
req.params.progress,
440+
req.params.total,
441+
req.params.message,
436442
)
437443

438444
self.notification_handlers[types.ProgressNotification] = handler

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class ProgressNotificationParams(NotificationParams):
338338
"""
339339
total: float | None = None
340340
"""
341-
Message related to progress. This should provide relevant human readble
341+
Message related to progress. This should provide relevant human readable
342342
progress information.
343343
"""
344344
message: str | None = None

tests/issues/test_176_progress_token.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call():
3939
mock_session.send_progress_notification.call_count == 3
4040
), "All progress notifications should be sent"
4141
mock_session.send_progress_notification.assert_any_call(
42-
progress_token=0, progress=0.0, total=10.0
42+
progress_token=0, progress=0.0, total=10.0, message=None
4343
)
4444
mock_session.send_progress_notification.assert_any_call(
45-
progress_token=0, progress=5.0, total=10.0
45+
progress_token=0, progress=5.0, total=10.0, message=None
4646
)
4747
mock_session.send_progress_notification.assert_any_call(
48-
progress_token=0, progress=10.0, total=10.0
48+
progress_token=0, progress=10.0, total=10.0, message=None
4949
)
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import anyio
2+
import pytest
3+
4+
import mcp.types as types
5+
from mcp.client.session import ClientSession
6+
from mcp.server import Server
7+
from mcp.server.lowlevel import NotificationOptions
8+
from mcp.server.models import InitializationOptions
9+
from mcp.server.session import ServerSession
10+
from mcp.shared.session import RequestResponder
11+
from mcp.types import (
12+
JSONRPCMessage,
13+
)
14+
15+
16+
@pytest.mark.anyio
17+
async def test_bidirectional_progress_notifications():
18+
"""Test that both client and server can send progress notifications."""
19+
# Create memory streams for client/server
20+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
21+
JSONRPCMessage
22+
](5)
23+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
24+
JSONRPCMessage
25+
](5)
26+
27+
# Run a server session so we can send progress updates in tool
28+
async def run_server():
29+
# Create a server session
30+
async with ServerSession(
31+
client_to_server_receive,
32+
server_to_client_send,
33+
InitializationOptions(
34+
server_name="ProgressTestServer",
35+
server_version="0.1.0",
36+
capabilities=server.get_capabilities(NotificationOptions(), {}),
37+
),
38+
) as server_session:
39+
global serv_sesh
40+
41+
serv_sesh = server_session
42+
async for message in server_session.incoming_messages:
43+
try:
44+
await server._handle_message(message, server_session, ())
45+
except Exception as e:
46+
raise e
47+
48+
# Track progress updates
49+
server_progress_updates = []
50+
client_progress_updates = []
51+
52+
# Progress tokens
53+
server_progress_token = "server_token_123"
54+
client_progress_token = "client_token_456"
55+
56+
# Create a server with progress capability
57+
server = Server(name="ProgressTestServer")
58+
59+
# Register progress handler
60+
@server.progress_notification()
61+
async def handle_progress(
62+
progress_token: str | int,
63+
progress: float,
64+
total: float | None,
65+
message: str | None,
66+
):
67+
server_progress_updates.append(
68+
{
69+
"token": progress_token,
70+
"progress": progress,
71+
"total": total,
72+
"message": message,
73+
}
74+
)
75+
76+
# Register list tool handler
77+
@server.list_tools()
78+
async def handle_list_tools() -> list[types.Tool]:
79+
return [
80+
types.Tool(
81+
name="test_tool",
82+
description="A tool that sends progress notifications <o/",
83+
inputSchema={},
84+
)
85+
]
86+
87+
# Register tool handler
88+
@server.call_tool()
89+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
90+
# Make sure we received a progress token
91+
if name == "test_tool":
92+
if arguments and "_meta" in arguments:
93+
progressToken = arguments["_meta"]["progressToken"]
94+
95+
if not progressToken:
96+
raise ValueError("Empty progress token received")
97+
98+
if progressToken != client_progress_token:
99+
raise ValueError("Server sending back incorrect progressToken")
100+
101+
# Send progress notifications
102+
await serv_sesh.send_progress_notification(
103+
progress_token=progressToken,
104+
progress=0.25,
105+
total=1.0,
106+
message="Server progress 25%",
107+
)
108+
await anyio.sleep(0.2)
109+
110+
await serv_sesh.send_progress_notification(
111+
progress_token=progressToken,
112+
progress=0.5,
113+
total=1.0,
114+
message="Server progress 50%",
115+
)
116+
await anyio.sleep(0.2)
117+
118+
await serv_sesh.send_progress_notification(
119+
progress_token=progressToken,
120+
progress=1.0,
121+
total=1.0,
122+
message="Server progress 100%",
123+
)
124+
125+
else:
126+
raise ValueError("Progress token not sent.")
127+
128+
return ["Tool executed successfully"]
129+
130+
raise ValueError(f"Unknown tool: {name}")
131+
132+
# Client message handler to store progress notifications
133+
async def handle_client_message(
134+
message: RequestResponder[types.ServerRequest, types.ClientResult]
135+
| types.ServerNotification
136+
| Exception,
137+
) -> None:
138+
if isinstance(message, Exception):
139+
raise message
140+
141+
if isinstance(message, types.ServerNotification):
142+
if isinstance(message.root, types.ProgressNotification):
143+
params = message.root.params
144+
client_progress_updates.append(
145+
{
146+
"token": params.progressToken,
147+
"progress": params.progress,
148+
"total": params.total,
149+
"message": params.message,
150+
}
151+
)
152+
153+
# Test using client
154+
async with (
155+
ClientSession(
156+
server_to_client_receive,
157+
client_to_server_send,
158+
message_handler=handle_client_message,
159+
) as client_session,
160+
anyio.create_task_group() as tg,
161+
):
162+
# Start the server in a background task
163+
tg.start_soon(run_server)
164+
165+
# Initialize the client connection
166+
await client_session.initialize()
167+
168+
# Call list_tools with progress token
169+
await client_session.list_tools()
170+
171+
# Call test_tool with progress token
172+
await client_session.call_tool(
173+
"test_tool", {"_meta": {"progressToken": client_progress_token}}
174+
)
175+
176+
# Send progress notifications from client to server
177+
await client_session.send_progress_notification(
178+
progress_token=server_progress_token,
179+
progress=0.33,
180+
total=1.0,
181+
message="Client progress 33%",
182+
)
183+
184+
await client_session.send_progress_notification(
185+
progress_token=server_progress_token,
186+
progress=0.66,
187+
total=1.0,
188+
message="Client progress 66%",
189+
)
190+
191+
await client_session.send_progress_notification(
192+
progress_token=server_progress_token,
193+
progress=1.0,
194+
total=1.0,
195+
message="Client progress 100%",
196+
)
197+
198+
# Wait and exit
199+
await anyio.sleep(1.0)
200+
tg.cancel_scope.cancel()
201+
202+
# Verify client received progress updates from server
203+
assert len(client_progress_updates) == 3
204+
assert client_progress_updates[0]["token"] == client_progress_token
205+
assert client_progress_updates[0]["progress"] == 0.25
206+
assert client_progress_updates[0]["message"] == "Server progress 25%"
207+
assert client_progress_updates[2]["progress"] == 1.0
208+
209+
# Verify server received progress updates from client
210+
assert len(server_progress_updates) == 3
211+
assert server_progress_updates[0]["token"] == server_progress_token
212+
assert server_progress_updates[0]["progress"] == 0.33
213+
assert server_progress_updates[0]["message"] == "Client progress 33%"
214+
assert server_progress_updates[2]["progress"] == 1.0

0 commit comments

Comments
 (0)