Skip to content
56 changes: 48 additions & 8 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,23 +548,60 @@ async def handler(req: types.CallToolRequest):

def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
func: Callable[
[str | int, float, float | None, str | None, ServerSession | None],
Awaitable[None],
],
):
logger.debug("Registering handler for ProgressNotification")

async def handler(req: types.ProgressNotification):
async def handler(
req: types.ProgressNotification,
session: ServerSession | None = None,
):
await func(
req.params.progressToken,
req.params.progress,
req.params.total,
req.params.message,
session,
)

self.notification_handlers[types.ProgressNotification] = handler
return func

return decorator

def initialized_notification(self):
"""Decorator to register a handler for InitializedNotification."""

def decorator(
func: Callable[
[types.InitializedNotification, ServerSession | None],
Awaitable[None],
],
):
logger.debug("Registering handler for InitializedNotification")
self.notification_handlers[types.InitializedNotification] = func
return func

return decorator

def roots_list_changed_notification(self):
"""Decorator to register a handler for RootsListChangedNotification."""

def decorator(
func: Callable[
[types.RootsListChangedNotification, ServerSession | None],
Awaitable[None],
],
):
logger.debug("Registering handler for RootsListChangedNotification")
self.notification_handlers[types.RootsListChangedNotification] = func
return func

return decorator

def completion(self):
"""Provides completions for prompts and resource templates"""

Expand Down Expand Up @@ -636,7 +673,7 @@ async def run(

async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception),
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
Expand All @@ -648,10 +685,10 @@ async def _handle_message(
with responder:
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
await self._handle_notification(notify, session)

for warning in w:
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
for warning in w:
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)

async def _handle_request(
self,
Expand Down Expand Up @@ -712,12 +749,15 @@ async def _handle_request(

logger.debug("Response sent")

async def _handle_notification(self, notify: Any):
async def _handle_notification(self, notify: Any, session: ServerSession):
if handler := self.notification_handlers.get(type(notify)): # type: ignore
logger.debug("Dispatching notification of type %s", type(notify).__name__)

try:
await handler(notify)
try:
await handler(notify, session)
except TypeError:
await handler(notify)
except Exception:
logger.exception("Uncaught exception in notification handler")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, cast
from unittest.mock import patch

Expand All @@ -12,8 +13,9 @@
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.message import SessionMessage
from mcp.shared.progress import progress
from mcp.shared.session import BaseSession, RequestResponder, SessionMessage
from mcp.shared.session import BaseSession, RequestResponder


@pytest.mark.anyio
Expand All @@ -23,6 +25,8 @@ async def test_bidirectional_progress_notifications():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)

server_session_ref: list[ServerSession | None] = [None]

# Run a server session so we can send progress updates in tool
async def run_server():
# Create a server session
Expand All @@ -35,9 +39,7 @@ async def run_server():
capabilities=server.get_capabilities(NotificationOptions(), {}),
),
) as server_session:
global serv_sesh

serv_sesh = server_session
server_session_ref[0] = server_session
async for message in server_session.incoming_messages:
try:
await server._handle_message(message, server_session, {})
Expand All @@ -62,6 +64,7 @@ async def handle_progress(
progress: float,
total: float | None,
message: str | None,
session: ServerSession | None,
):
server_progress_updates.append(
{
Expand All @@ -86,6 +89,10 @@ async def handle_list_tools() -> list[types.Tool]:
# Register tool handler
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]:
serv_sesh = server_session_ref[0]
if not serv_sesh:
raise ValueError("Server session not available")

# Make sure we received a progress token
if name == "test_tool":
if arguments and "_meta" in arguments:
Expand Down Expand Up @@ -228,6 +235,7 @@ async def handle_progress(
progress: float,
total: float | None,
message: str | None,
session: ServerSession | None,
):
server_progress_updates.append(
{"token": progress_token, "progress": progress, "total": total, "message": message}
Expand Down Expand Up @@ -388,3 +396,106 @@ async def handle_list_tools() -> list[types.Tool]:
# Check that a warning was logged for the progress callback exception
assert len(logged_errors) > 0
assert any("Progress callback raised an exception" in warning for warning in logged_errors)


@pytest.mark.anyio
async def test_initialized_notification():
"""Test that the server receives and handles InitializedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
initialized_received = asyncio.Event()
received_session: ServerSession | None = None

@server.initialized_notification()
async def handle_initialized(
notification: types.InitializedNotification,
session: ServerSession | None = None,
):
nonlocal received_session
received_session = session
initialized_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await initialized_received.wait()
tg.cancel_scope.cancel()

assert initialized_received.is_set()
assert isinstance(received_session, ServerSession)


@pytest.mark.anyio
async def test_roots_list_changed_notification():
"""Test that the server receives and handles RootsListChangedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
roots_list_changed_received = asyncio.Event()
received_session: ServerSession | None = None

@server.roots_list_changed_notification()
async def handle_roots_list_changed(
notification: types.RootsListChangedNotification,
session: ServerSession | None = None,
):
nonlocal received_session
received_session = session
roots_list_changed_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await client_session.send_notification(
types.ClientNotification(
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
)
)
await roots_list_changed_received.wait()
tg.cancel_scope.cancel()

assert roots_list_changed_received.is_set()
assert isinstance(received_session, ServerSession)
Loading