Skip to content
62 changes: 52 additions & 10 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ async def main():
from __future__ import annotations as _annotations

import contextvars
import inspect
import json
import logging
import warnings
Expand Down Expand Up @@ -104,6 +105,9 @@ async def main():
# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")

# Context variable to hold the current ServerSession, accessible by notification handlers
current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session")


class NotificationOptions:
def __init__(
Expand Down Expand Up @@ -525,6 +529,36 @@ async def handler(req: types.ProgressNotification):

return decorator

def initialized_notification(self):
"""Decorator to register a handler for InitializedNotification."""

def decorator(
func: (
Callable[[types.InitializedNotification, ServerSession], Awaitable[None]]
| Callable[[types.InitializedNotification], Awaitable[None]]
),
):
logger.debug("Registering handler for InitializedNotification")
self.notification_handlers[types.InitializedNotification] = func
return func

return decorator

def roots_list_changed_notification(self):
"""Decorator to register a handler for RootsListChangedNotification."""

def decorator(
func: (
Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]]
| Callable[[types.RootsListChangedNotification], Awaitable[None]]
),
):
logger.debug("Registering handler for RootsListChangedNotification")
self.notification_handlers[types.RootsListChangedNotification] = func
return func

return decorator

def completion(self):
"""Provides completions for prompts and resource templates"""

Expand Down Expand Up @@ -596,22 +630,26 @@ async def run(

async def _handle_message(
self,
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception),
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
with responder:
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
session_token = current_session_ctx.set(session)
try:
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive]
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
with responder:
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

for warning in w:
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
finally:
current_session_ctx.reset(session_token)

async def _handle_request(
self,
Expand Down Expand Up @@ -677,7 +715,11 @@ async def _handle_notification(self, notify: Any):
logger.debug("Dispatching notification of type %s", type(notify).__name__)

try:
await handler(notify)
sig = inspect.signature(handler)
if "session" in sig.parameters:
await handler(notify, current_session_ctx.get())
else:
await handler(notify)
except Exception:
logger.exception("Uncaught exception in notification handler")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, cast

import anyio
Expand All @@ -10,6 +11,7 @@
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.progress import progress
from mcp.shared.session import BaseSession, RequestResponder, SessionMessage

Expand Down Expand Up @@ -320,3 +322,191 @@ async def handle_client_message(
assert server_progress_updates[3]["progress"] == 100
assert server_progress_updates[3]["total"] == 100
assert server_progress_updates[3]["message"] == "Processing results..."


@pytest.mark.anyio
async def test_initialized_notification():
"""Test that the server receives and handles InitializedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
initialized_received = asyncio.Event()

@server.initialized_notification()
async def handle_initialized(notification: types.InitializedNotification):
initialized_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await initialized_received.wait()
tg.cancel_scope.cancel()

assert initialized_received.is_set()


@pytest.mark.anyio
async def test_roots_list_changed_notification():
"""Test that the server receives and handles RootsListChangedNotification."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
roots_list_changed_received = asyncio.Event()

@server.roots_list_changed_notification()
async def handle_roots_list_changed(
notification: types.RootsListChangedNotification,
):
roots_list_changed_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await client_session.send_notification(
types.ClientNotification(
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
)
)
await roots_list_changed_received.wait()
tg.cancel_scope.cancel()

assert roots_list_changed_received.is_set()


@pytest.mark.anyio
async def test_initialized_notification_with_session():
"""Test that the server receives and handles InitializedNotification with a session."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
initialized_received = asyncio.Event()
received_session = None

@server.initialized_notification()
async def handle_initialized(notification: types.InitializedNotification, session: ServerSession):
nonlocal received_session
received_session = session
initialized_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await initialized_received.wait()
tg.cancel_scope.cancel()

assert initialized_received.is_set()
assert isinstance(received_session, ServerSession)


@pytest.mark.anyio
async def test_roots_list_changed_notification_with_session():
"""Test that the server receives and handles RootsListChangedNotification with a session."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

server = Server("test")
roots_list_changed_received = asyncio.Event()
received_session = None

@server.roots_list_changed_notification()
async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession):
nonlocal received_session
received_session = session
roots_list_changed_received.set()

async def run_server():
await server.run(
client_to_server_receive,
server_to_client_send,
server.create_initialization_options(),
)

async def message_handler(
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
) -> None:
if isinstance(message, Exception):
raise message

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as client_session,
anyio.create_task_group() as tg,
):
tg.start_soon(run_server)
await client_session.initialize()
await client_session.send_notification(
types.ClientNotification(
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
)
)
await roots_list_changed_received.wait()
tg.cancel_scope.cancel()

assert roots_list_changed_received.is_set()
assert isinstance(received_session, ServerSession)
Loading